Submitted by:
| # | Name | Id | Email |
|---|---|---|---|
| Student 1 | Lior Sherman | 307932277 | lior.s@campus.technion.ac.il |
| Student 2 | Tal Rosenseweig | 307965806 | tal.r@campus.technion.ac.il |
In this assignment we'll learn to generate text with a deep multilayer RNN network based on GRU cells. Then we'll focus our attention on image generation and implement two different generative models: A variational autoencoder and a generative adversarial network.
In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!
import unittest
import os
import sys
import pathlib
import urllib.request
import shutil
import re
import numpy as np
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda
Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
    pathlib.Path(out_path).mkdir(exist_ok=True)
    out_filename = os.path.join(out_path, os.path.basename(url))

    if os.path.isfile(out_filename) and not force:
        print(f'Corpus file {out_filename} exists, skipping download.')
    else:
        print(f'Downloading {url}...')
        with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        print(f'Saved to {out_filename}.')
    return out_filename

corpus_path = download_corpus()
Corpus file /home/lior.s/.pytorch-datasets/shakespeare.txt exists, skipping download.
Load the text into memory and print a snippet:
with open(corpus_path, 'r', encoding='utf-8') as f:
    corpus = f.read()
print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL
by William Shakespeare
Dramatis Personae
KING OF FRANCE
THE DUKE OF FLORENCE
BERTRAM, Count of Rousillon
LAFEU, an old lord
PAROLLES, a follower of Bertram
TWO FRENCH LORDS, serving with Bertram
STEWARD, Servant to the Countess of Rousillon
LAVACHE, a clown and Servant to the Countess of Rousillon
A PAGE, Servant to the Countess of Rousillon
COUNTESS OF ROUSILLON, mother to Bertram
HELENA, a gentlewoman protected by the Countess
A WIDOW OF FLORENCE.
DIANA, daughter to the Widow
VIOLENTA, neighbour and friend to the Widow
MARIANA, neighbour and friend to the Widow
Lords, Officers, Soldiers, etc., French and Florentine
SCENE:
Rousillon; Paris; Florence; Marseilles
ACT I. SCENE 1.
Rousillon. The COUNT'S palace
Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black
COUNTESS. In delivering my son from me, I bury a second husband.
BERTRAM. And I in going, madam, weep o'er my father's death anew;
but I must attend his Majesty's command, to whom I am now in
ward, evermore in subjection.
LAFEU. You shall find of the King a husband, madam; you, sir, a
father. He that so generally is at all times good must of
The first thing we'll need is to map from each unique character in the corpus to an index that will represent it in our learning process.
TODO: Implement the char_maps() function in the hw3/charnn.py module.
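As a starting point, here's a minimal sketch of what `char_maps()` might look like. It assumes chars are indexed by their sorted order, which is consistent with the mapping printed in this notebook, but the exact convention is up to your implementation.

```python
def char_maps(text):
    # Sketch: index each unique char by its sorted position (an assumption,
    # but one matching the mapping printed in this notebook).
    unique_chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(unique_chars)}
    idx_to_char = {i: c for i, c in enumerate(unique_chars)}
    return char_to_idx, idx_to_char
```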
import hw3.charnn as charnn
char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)
test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'\n': 0, ' ': 1, '!': 2, '"': 3, '$': 4, '&': 5, "'": 6, '(': 7, ')': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21, ':': 22, ';': 23, '<': 24, '?': 25, 'A': 26, 'B': 27, 'C': 28, 'D': 29, 'E': 30, 'F': 31, 'G': 32, 'H': 33, 'I': 34, 'J': 35, 'K': 36, 'L': 37, 'M': 38, 'N': 39, 'O': 40, 'P': 41, 'Q': 42, 'R': 43, 'S': 44, 'T': 45, 'U': 46, 'V': 47, 'W': 48, 'X': 49, 'Y': 50, 'Z': 51, '[': 52, ']': 53, '_': 54, 'a': 55, 'b': 56, 'c': 57, 'd': 58, 'e': 59, 'f': 60, 'g': 61, 'h': 62, 'i': 63, 'j': 64, 'k': 65, 'l': 66, 'm': 67, 'n': 68, 'o': 69, 'p': 70, 'q': 71, 'r': 72, 's': 73, 't': 74, 'u': 75, 'v': 76, 'w': 77, 'x': 78, 'y': 79, 'z': 80, '}': 81, '\ufeff': 82}
It seems we have some strange characters in the corpus that are very rare and are probably due to mistakes. Since they inflate the length of the tensors we'll later use to represent our chars, it's best to remove them.
TODO: Implement the remove_chars() function in the hw3/charnn.py module.
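One possible sketch of `remove_chars()`, matching the call and return signature used below (the one-pass `str.translate` approach is just one way to do it):

```python
def remove_chars(text, chars_to_remove):
    # str.translate removes every occurrence of the given chars in one pass.
    clean = text.translate({ord(c): None for c in chars_to_remove})
    return clean, len(text) - len(clean)
```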
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')
# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 34 chars
The next thing we need is an embedding of the characters.
An embedding is a representation of each token from the sequence as a tensor.
For a char-level RNN, our tokens will be chars and we can thus use the simplest possible embedding: encode each char as a one-hot tensor. In other words, each char will be represented as a tensor whose length is the total number of unique chars (V), containing all zeros except at the index corresponding to that specific char.
TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.
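A hedged sketch of both embedding functions; the `int8` dtype and the (S, V) shape follow the test cell below, but the implementation details are one option among several:

```python
import torch

def chars_to_onehot(text, char_to_idx):
    # (S, V) int8 tensor; row t is all zeros except a 1 at text[t]'s index.
    S, V = len(text), len(char_to_idx)
    idx = torch.tensor([char_to_idx[c] for c in text])
    onehot = torch.zeros(S, V, dtype=torch.int8)
    onehot[torch.arange(S), idx] = 1
    return onehot

def onehot_to_chars(onehot, idx_to_char):
    # Invert: the position of the 1 in each row is the char's index.
    return ''.join(idx_to_char[i.item()] for i in torch.argmax(onehot, dim=1))
```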
# Wrap the actual embedding functions for calling convenience
def embed(text):
    return charnn.chars_to_onehot(text, char_to_idx)

def unembed(embedding):
    return charnn.onehot_to_chars(embedding, idx_to_char)
text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))
test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0]], dtype=torch.int8)
We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.
We will split our corpus into shorter sequences of length S chars (see question below).
Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence.
For each sample, we'll also need a label. This is simply another sequence, shifted by one char so that the label of each char is the next char in the corpus.
TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.
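A sketch of one way `chars_to_labelled_samples()` could work, consistent with the shapes tested below (samples of shape (N, S, V) and integer labels of shape (N, S), shifted by one char):

```python
import torch

def chars_to_labelled_samples(corpus, char_to_idx, seq_len, device='cpu'):
    V = len(char_to_idx)
    # The last char has no "next char" to serve as its label, hence the -1.
    N = (len(corpus) - 1) // seq_len
    idx = torch.tensor([char_to_idx[c] for c in corpus[:N * seq_len + 1]])
    # Labels: the same char indices, shifted forward by one.
    labels = idx[1:].reshape(N, seq_len).to(device)
    # Samples: one-hot embed the first N*seq_len chars.
    onehot = torch.zeros(N * seq_len, V, dtype=torch.int8)
    onehot[torch.arange(N * seq_len), idx[:-1]] = 1
    samples = onehot.reshape(N, seq_len, V).to(device)
    return samples, labels
```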
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)
# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')
# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))
# Test content
for _ in range(1000):
    # Random sample index
    i = np.random.randint(num_samples, size=(1,))[0]
    # Compare to corpus
    test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
    # Compare to labels
    sample_text = unembed(samples[i])
    label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
    test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])
Let's print a few consecutive samples. You should see that the text continues between them.
import re
import random
i = random.randrange(num_samples-5)
for i in range(i, i+5):
    s = re.sub(r'\s+', ' ', unembed(samples[i])).strip()
    print(f'sample [{i}]:\n\t{s}')
sample [77667]: an end. I am resolv'd that thou shalt spend some time W sample [77668]: ith Valentinus in the Emperor's court; What maintenance he f sample [77669]: rom his friends receives, Like exhibition thou shalt have fr sample [77670]: om me. To-morrow be in readiness to go- Excuse it not, f sample [77671]: or I am peremptory. PROTEUS. My lord, I cannot be so soon prov
As usual, instead of feeding one sample at a time into our model's forward function, we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that come from different sequences. Effectively, this allows us to parallelize training by doing matrix-matrix multiplications instead of matrix-vector during the forward pass.
An important nuance is that we need the batches to be contiguous, i.e. sample $k$ in batch $j$ should continue sample $k$ from batch $j-1$. The following figure illustrates this:

If we naïvely take consecutive samples into batches, e.g. [0,1,...,B-1], [B,B+1,...,2B-1] and so on, we won't have contiguous
sequences at the same index between adjacent batches.
To accomplish this we need to tell our DataLoader which samples to combine together into one batch.
We do this by implementing a custom PyTorch Sampler, and providing it to our DataLoader.
TODO: Implement the SequenceBatchSampler class in the hw3/charnn.py module.
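A sketch of the sampler, reverse-engineered from the expected index order shown in the test below (batch $j$ contains samples $j, j+\text{num\_batches}, j+2\cdot\text{num\_batches}, \dots$, so sample $k$ of batch $j+1$ continues sample $k$ of batch $j$):

```python
import torch.utils.data

class SequenceBatchSampler(torch.utils.data.Sampler):
    """Yields indices so that sample k of batch j+1 is the direct
    continuation (in the corpus) of sample k of batch j."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        num_batches = len(self.dataset) // self.batch_size
        # Batch j contains samples j, j+num_batches, j+2*num_batches, ...
        for j in range(num_batches):
            for k in range(self.batch_size):
                yield j + k * num_batches

    def __len__(self):
        # Leftover samples that don't fill a whole batch are dropped.
        return (len(self.dataset) // self.batch_size) * self.batch_size
```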
from hw3.charnn import SequenceBatchSampler
sampler = SequenceBatchSampler(dataset=range(32), batch_size=10)
sampler_idx = list(sampler)
print('sampler_idx =\n', sampler_idx)
# Test the Sampler
test.assertEqual(len(sampler_idx), 30)
batch_idx = np.array(sampler_idx).reshape(-1, 10)
for k in range(10):
    test.assertEqual(np.diff(batch_idx[:, k], n=2).item(), 0)
sampler_idx = [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29]
Even though we're working with sequences, we can still use the standard PyTorch Dataset/DataLoader combo.
For the dataset we can use a built-in class, TensorDataset to return tuples of (sample, label)
from the samples and labels tensors we created above.
The DataLoader will be provided with our custom Sampler so that it generates appropriate batches.
import torch.utils.data
# Create DataLoader returning batches of samples.
batch_size = 32
ds_corpus = torch.utils.data.TensorDataset(samples, labels)
sampler_corpus = SequenceBatchSampler(ds_corpus, batch_size)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, sampler=sampler_corpus, shuffle=False)
Let's see what that gives us:
print(f'num batches: {len(dl_corpus)}')
x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch of samples: {x0.shape}')
print(f'shape of a batch of labels: {y0.shape}')
num batches: 3100
shape of a batch of samples: torch.Size([32, 64, 78])
shape of a batch of labels: torch.Size([32, 64])
Now lets look at the same sample index from multiple batches taken from our corpus.
# Check that sentences at the same index of different batches complete each other.
k = random.randrange(batch_size)
for j, (X, y) in enumerate(dl_corpus):
    print(f'=== batch {j}, sample {k} ({X[k].shape}): ===')
    s = re.sub(r'\s+', ' ', unembed(X[k])).strip()
    print(f'\t{s}')
    if j == 4:
        break
=== batch 0, sample 20 (torch.Size([64, 78])): === e royal blood For thee to slaughter. For my daughters, Richa === batch 1, sample 20 (torch.Size([64, 78])): === rd, They shall be praying nuns, not weeping queens; And === batch 2, sample 20 (torch.Size([64, 78])): === therefore level not to hit their lives. KING RICHARD. You have === batch 3, sample 20 (torch.Size([64, 78])): === a daughter call'd Elizabeth. Virtuous and fair, royal and g === batch 4, sample 20 (torch.Size([64, 78])): === racious. QUEEN ELIZABETH. And must she die for this? O, let he
Finally, our data set is ready so we can focus on our model.
What we'll implement here is a multilayer gated recurrent unit (GRU) model with dropout. This model is a type of RNN which performs similarly to the well-known LSTM model, but is somewhat easier to train because it has fewer parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.
The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.
Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as
$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$
The input to each layer is
$$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \\ \vdots \\ {\vec{x}_S}^{[k]} \end{bmatrix}, $$
where $\mat{X}^{[1]} = \mat{X}$ is the embedded input sequence, and for $k>1$ each row is the corresponding hidden state of the previous layer (with dropout applied), i.e. $\vec{x}_t^{[k]} = \vec{h}_t^{[k-1]}$.
The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$
and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$
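To make the equations concrete, here's a sketch of a single GRU timestep at one layer. Note that `gru_cell_step` is a hypothetical helper, not part of the required `MultilayerGRU` API; it follows the row-vector convention of the equations above, with `Wx*` of shape (V, H), `Wh*` of shape (H, H), and biases of shape (H,).

```python
import torch

def gru_cell_step(x, h_prev, Wxz, Whz, bz, Wxr, Whr, br, Wxg, Whg, bg):
    # Update gate, reset gate, candidate state, then the convex combination
    # of the previous state and the candidate (the four equations above).
    z = torch.sigmoid(x @ Wxz + h_prev @ Whz + bz)
    r = torch.sigmoid(x @ Wxr + h_prev @ Whr + br)
    g = torch.tanh(x @ Wxg + (r * h_prev) @ Whg + bg)
    return z * h_prev + (1 - z) * g
```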
Notes:
Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore its previous state, completely ignore its input, or any mixture of the two (since the gates are continuous, with values in $(0,1)$).
Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (i.e., on the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.
TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.
in_dim = vocab_len
h_dim = 256
n_layers = 3
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)
# Test forward pass
y, h = model(x0.to(dtype=torch.float))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')
test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2)
MultilayerGRU( (w_xz_0): Linear(in_features=78, out_features=256, bias=False) (w_hz_0): Linear(in_features=256, out_features=256, bias=True) (w_xr_0): Linear(in_features=78, out_features=256, bias=False) (w_hr_0): Linear(in_features=256, out_features=256, bias=True) (w_xg_0): Linear(in_features=78, out_features=256, bias=False) (w_hg_0): Linear(in_features=256, out_features=256, bias=True) (w_xz_1): Linear(in_features=256, out_features=256, bias=False) (w_hz_1): Linear(in_features=256, out_features=256, bias=True) (w_xr_1): Linear(in_features=256, out_features=256, bias=False) (w_hr_1): Linear(in_features=256, out_features=256, bias=True) (w_xg_1): Linear(in_features=256, out_features=256, bias=False) (w_hg_1): Linear(in_features=256, out_features=256, bias=True) (w_xz_2): Linear(in_features=256, out_features=256, bias=False) (w_hz_2): Linear(in_features=256, out_features=256, bias=True) (w_xr_2): Linear(in_features=256, out_features=256, bias=False) (w_hr_2): Linear(in_features=256, out_features=256, bias=True) (w_xg_2): Linear(in_features=256, out_features=256, bias=False) (w_hg_2): Linear(in_features=256, out_features=256, bias=True) (dropout): Dropout(p=0, inplace=False) (out): Linear(in_features=256, out_features=78, bias=True) ) y.shape=torch.Size([32, 64, 78]) h.shape=torch.Size([32, 3, 256])
Now that we have a model, we can implement text generation based on it. The idea is simple: At each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability over each of the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t, \vec{h}_t).$$
Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.
The important point however is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it can generate very diffuse (more uniform) distributions if the score values are very similar. When sampling, we would prefer to control the distributions and make them less uniform to increase the chance of sampling the char(s) with the highest scores compared to the others.
To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature, to the softmax function. The class scores are simply divided by $T$ before the softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$
A low $T$ will result in less uniform distributions and vice-versa.
TODO: Implement the hot_softmax() function in the hw3/charnn.py module.
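Given the formula above, `hot_softmax()` can be little more than a scaled softmax; a minimal sketch (the `dim` default is an assumption, matching the 1-D score vectors used below):

```python
import torch

def hot_softmax(y, dim=0, temperature=1.0):
    # Divide scores by T before softmax: T < 1 sharpens the distribution,
    # T > 1 flattens it towards uniform.
    return torch.softmax(y / temperature, dim=dim)
```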
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))
for t in reversed([0.3, 0.5, 1.0, 100]):
ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()
uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))
TODO: Implement the generate_from_model() function in the hw3/charnn.py module.
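A sketch of the sampling loop described above. It assumes the model's forward accepts an optional hidden state and returns `(scores, hidden_state)`, as this notebook's `MultilayerGRU` does; the inline `onehot` helper is illustrative only.

```python
import torch

def generate_from_model(model, start_sequence, n_chars, char_maps, T=0.5):
    char_to_idx, idx_to_char = char_maps
    V = len(char_to_idx)

    def onehot(text):
        # Minimal inline one-hot embedding: (1, S, V) float for the model.
        x = torch.zeros(len(text), V)
        for t, c in enumerate(text):
            x[t, char_to_idx[c]] = 1
        return x.unsqueeze(0)

    out_text = start_sequence
    with torch.no_grad():
        h = None
        x = onehot(start_sequence)
        while len(out_text) < n_chars:
            # Crucial: propagate the hidden state between calls.
            y, h = model(x, h)
            # Temperature softmax over the last timestep's scores.
            p = torch.softmax(y[0, -1, :] / T, dim=0)
            next_idx = torch.multinomial(p, num_samples=1).item()
            out_text += idx_to_char[next_idx]
            # Feed only the newly sampled char on the next iteration.
            x = onehot(out_text[-1])
    return out_text
```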
for _ in range(3):
    text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
    print(text)
    test.assertEqual(len(text), 50)
foobarob"1hTp4-).mxt97?02vG1-9VVuB1 M&!)7qMU0DNxtm foobarx;iOcHUS7k171jB: F0[7(fum8mywOT!yFe]ZxD!wFo] foobar'[xrktrx!O?NJj;t32xQYCP,(MC!jdK09Y:vW&scdRl"
To train this model, we'll calculate the loss at each time step by comparing the predicted char to
the actual char from our label. We can use cross entropy since per char it's similar to a classification problem.
We'll then sum the losses over the sequence and back-propagate the gradients through time.
Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times,
so we'll accumulate gradients in parameters of the blocks. Luckily autograd will handle this part for us.
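Since per-char prediction is just classification over $V$ classes, the per-timestep losses can be computed in one shot by folding the time dimension into the batch dimension. A sketch with dummy scores (the shapes mirror our model's outputs; the tensors here are random placeholders):

```python
import torch
import torch.nn as nn

# Dummy per-timestep scores and next-char labels, shaped like our model's outputs.
B, S, V = 32, 64, 78
y = torch.randn(B, S, V)
labels = torch.randint(V, (B, S))

# CrossEntropyLoss expects (N, C) scores and (N,) targets,
# so we fold the time dimension into the batch dimension.
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y.reshape(B * S, V), labels.reshape(B * S))
```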
As usual, the first step of training will be to try and overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.
For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is to get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.
Let's create a tiny dataset to memorize.
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
batch_size_ss = 1
sampler_ss = SequenceBatchSampler(ds_corpus_ss, batch_size=batch_size_ss)
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size_ss, sampler=sampler_ss, shuffle=False)
# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
    subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":
TRAM. What would you have?
HELENA. Something; and scarce so much; nothing, indeed.
I would not tell you what I would, my lord.
Faith, yes:
Strangers and foes do sunder and not kiss.
BERTRAM. I pray you, stay not, but in haste to horse.
HE
Now let's implement the first part of our training code.
TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module.
You must think about how to correctly handle the hidden state of the model between batches and epochs for this specific task (i.e. text generation).
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer
torch.manual_seed(42)
lr = 0.01
num_epochs = 500
in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)
for epoch in range(num_epochs):
    epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)

    # Every X epochs, we'll generate a sequence starting from the first char in the first sequence
    # to visualize how/if/what the model is learning.
    if epoch == 0 or (epoch+1) % 25 == 0:
        avg_loss = np.mean(epoch_result.losses)
        accuracy = np.mean(epoch_result.accuracy)
        print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
        generated_sequence = charnn.generate_from_model(model, subset_text[0],
                                                        seq_len*(subset_end-subset_start),
                                                        (char_to_idx, idx_to_char), T=0.1)
        print(generated_sequence)
        # Stop if we've successfully memorized the small dataset.
        if generated_sequence == subset_text:
            break
# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.940, Accuracy = 17.58%
Twn n
Epoch #25: Avg. loss = 0.285, Accuracy = 94.92%
TRAM. What would you have?
HELENA. Something; and scarce so much; not indeed.
Faith, yes:
Faith, yes:
Faith, yes:
Faith, yes:
Faith, yes:
Faith, yes:
Faith, yes:
Faith, yes:
Fait, yes:
Faith, yes:
Faith, yes:
Epoch #50: Avg. loss = 0.008, Accuracy = 100.00%
TRAM. What would you have?
HELENA. Something; and scarce so much; nothing, indeed.
I would not tell you what I would, my lord.
Faith, yes:
Strangers and foes do sunder and not kiss.
BERTRAM. I pray you, stay not, but in haste to horse.
HE
OK, so training works - we can memorize a short sequence. We'll now train a much larger model on our large dataset. You'll need a GPU for this part.
First, let's set up our dataset and model for training. We'll split our corpus into a 90% train-set and a 10% test-set. We'll also use a learning-rate scheduler to control the learning rate during training.
TODO: Set the hyperparameters in the part1_rnn_hyperparams() function of the hw3/answers.py module.
from hw3.answers import part1_rnn_hyperparams
hp = part1_rnn_hyperparams()
print('hyperparams:\n', hp)
### Dataset definition
vocab_len = len(char_to_idx)
batch_size = hp['batch_size']
seq_len = hp['seq_len']
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
sampler_train = SequenceBatchSampler(ds_train, batch_size)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size, shuffle=False, sampler=sampler_train, drop_last=True)
ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
sampler_test = SequenceBatchSampler(ds_test, batch_size)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size, shuffle=False, sampler=sampler_test, drop_last=True)
print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test: {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')
### Training definition
in_dim = out_dim = vocab_len
checkpoint_file = 'checkpoints/rnn'
num_epochs = 50
early_stopping = 2
model = charnn.MultilayerGRU(in_dim, hp['h_dim'], out_dim, hp['n_layers'], hp['dropout'])
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=hp['learn_rate'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=hp['lr_sched_factor'], patience=hp['lr_sched_patience'], verbose=True
)
trainer = RNNTrainer(model, loss_fn, optimizer, device)
hyperparams:
{'batch_size': 256, 'seq_len': 64, 'h_dim': 512, 'n_layers': 3, 'dropout': 0.5, 'learn_rate': 0.001, 'lr_sched_factor': 0.5, 'lr_sched_patience': 2}
Train: 348 batches, 5701632 chars
Test: 38 batches, 622592 chars
The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left.
Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.
TODO:
- Implement the fit() method of the Trainer class. You can reuse the relevant implementation parts from HW2, but make sure to implement early stopping and checkpoints.
- Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
- Once training is done, save your final checkpoint as checkpoints/rnn_final.pt. This will cause the block to skip training and instead load your saved model when running the homework submission script. Note that your submission zip file will not include the checkpoint file. This is OK.

from cs236781.plot import plot_fit
def post_epoch_fn(epoch, train_res, test_res, verbose):
    # Update learning rate
    scheduler.step(test_res.accuracy)
    # Sample from model to show progress
    if verbose:
        start_seq = "ACT I."
        generated_sequence = charnn.generate_from_model(
            model, start_seq, 100, (char_to_idx, idx_to_char), T=0.5
        )
        print(generated_sequence)
# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    saved_state = torch.load(checkpoint_file_final, map_location=device)
    model.load_state_dict(saved_state['model_state'])
else:
    try:
        # Print pre-training sample
        print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx, idx_to_char), T=0.5))
        fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=None,
                              post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
                              checkpoints=checkpoint_file, print_every=1)
        fig, axes = plot_fit(fit_res)
    except KeyboardInterrupt as e:
        print('\n *** Training interrupted by user')
ACT I.KqXYZ2t,])GPnSsfeBoLy36v""NsTJnksQ2,7RWw01lLOe
x9uxjldLo!vT-
aM(QFSL2tKz ]1Qg2!ToZ(1IYFPRpFWhV
--- EPOCH 1/50 ---
train_batch (Avg. Loss 2.159, Accuracy 40.3): 100%|██████████| 348/348 [01:27<00:00, 3.98it/s]
test_batch (Avg. Loss 1.805, Accuracy 45.6): 100%|██████████| 38/38 [00:03<00:00, 10.12it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 1
ACT I. What you may for the ending the enter
Whith the world have shall with a will evers the st
--- EPOCH 2/50 ---
train_batch (Avg. Loss 1.611, Accuracy 52.4): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.569, Accuracy 52.0): 100%|██████████| 38/38 [00:03<00:00, 10.15it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 2
ACT I. Sir, I do the world be the own too
good like a light of the came of his own to speak of t
--- EPOCH 3/50 ---
train_batch (Avg. Loss 1.472, Accuracy 56.2): 100%|██████████| 348/348 [01:27<00:00, 3.98it/s]
test_batch (Avg. Loss 1.468, Accuracy 54.8): 100%|██████████| 38/38 [00:03<00:00, 10.14it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 3
ACT I. So I may be stand.
SECOND MORTANDE. I know the morter and sound the bear him,
Shall I k
--- EPOCH 4/50 ---
train_batch (Avg. Loss 1.401, Accuracy 58.1): 100%|██████████| 348/348 [01:26<00:00, 4.01it/s]
test_batch (Avg. Loss 1.420, Accuracy 56.2): 100%|██████████| 38/38 [00:03<00:00, 10.08it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 4
ACT I. SCENE I.
The King of England with the letter of Somerset.
--- EPOCH 5/50 ---
train_batch (Avg. Loss 1.356, Accuracy 59.3): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.384, Accuracy 57.0): 100%|██████████| 38/38 [00:03<00:00, 9.97it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 5
ACT I. Still and I
CALIBAN. And I shall not be no love.
CAESAR. What was the way of Hector?
SE
--- EPOCH 6/50 ---
train_batch (Avg. Loss 1.323, Accuracy 60.0): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.359, Accuracy 57.7): 100%|██████████| 38/38 [00:03<00:00, 10.09it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 6
ACT I. Scene I
TIMON. You may prove me his tale of in the constable.
--- EPOCH 7/50 ---
train_batch (Avg. Loss 1.299, Accuracy 60.7): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.341, Accuracy 58.3): 100%|██████████| 38/38 [00:03<00:00, 10.11it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 7
ACT I. S,epherd and Lord Margaret,
--- EPOCH 8/50 ---
train_batch (Avg. Loss 1.280, Accuracy 61.1): 100%|██████████| 348/348 [01:26<00:00, 4.00it/s]
test_batch (Avg. Loss 1.321, Accuracy 58.7): 100%|██████████| 38/38 [00:03<00:00, 10.22it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 8
ACT I. SCENE I.
A son and Marcius Tominetal.
CASSIUS. I am a good fair hand.
Bene. The great ar
--- EPOCH 9/50 ---
train_batch (Avg. Loss 1.263, Accuracy 61.5): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.310, Accuracy 59.0): 100%|██████████| 38/38 [00:03<00:00, 10.14it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 9
ACT I. Stay, sir,
I say the care to stand my house with the fair person
That see him shall h
--- EPOCH 10/50 ---
train_batch (Avg. Loss 1.250, Accuracy 61.9): 100%|██████████| 348/348 [01:27<00:00, 3.99it/s]
test_batch (Avg. Loss 1.300, Accuracy 59.2): 100%|██████████| 38/38 [00:03<00:00, 10.03it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 10
ACT I.
Exit SERVANT
--- EPOCH 11/50 ---
train_batch (Avg. Loss 1.239, Accuracy 62.2): 100%|██████████| 348/348 [01:26<00:00, 4.00it/s]
test_batch (Avg. Loss 1.287, Accuracy 59.5): 100%|██████████| 38/38 [00:03<00:00, 10.12it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 11
ACT I. SCENE I.
Enter a GREMIO
--- EPOCH 12/50 ---
train_batch (Avg. Loss 1.229, Accuracy 62.4): 100%|██████████| 348/348 [01:26<00:00, 4.00it/s]
test_batch (Avg. Loss 1.280, Accuracy 59.6): 100%|██████████| 38/38 [00:03<00:00, 10.15it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 12
ACT I.
Exeunt.
Scene III.
Deniod. En
--- EPOCH 13/50 ---
train_batch (Avg. Loss 1.220, Accuracy 62.7): 100%|██████████| 348/348 [01:31<00:00, 3.79it/s]
test_batch (Avg. Loss 1.276, Accuracy 59.9): 100%|██████████| 38/38 [00:04<00:00, 9.44it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 13
ACT I.
Exeunt
SCENE III.
Consen
--- EPOCH 14/50 ---
train_batch (Avg. Loss 1.212, Accuracy 62.8): 100%|██████████| 348/348 [01:32<00:00, 3.75it/s]
test_batch (Avg. Loss 1.267, Accuracy 60.1): 100%|██████████| 38/38 [00:04<00:00, 9.38it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 14
ACT I. Still in the cause.
ANTIPHOLUS OF SYRACUSE. I have to do not so.
DUKE. What is the matter?
--- EPOCH 15/50 ---
train_batch (Avg. Loss 1.205, Accuracy 63.0): 100%|██████████| 348/348 [01:32<00:00, 3.75it/s]
test_batch (Avg. Loss 1.262, Accuracy 60.2): 100%|██████████| 38/38 [00:04<00:00, 9.37it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 15
ACT I. Speak.
Exeunt
--- EPOCH 16/50 ---
train_batch (Avg. Loss 1.199, Accuracy 63.2): 100%|██████████| 348/348 [01:33<00:00, 3.74it/s]
test_batch (Avg. Loss 1.259, Accuracy 60.3): 100%|██████████| 38/38 [00:04<00:00, 9.44it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 16
ACT I.
--- EPOCH 17/50 ---
train_batch (Avg. Loss 1.193, Accuracy 63.3): 100%|██████████| 348/348 [01:32<00:00, 3.75it/s]
test_batch (Avg. Loss 1.255, Accuracy 60.4): 100%|██████████| 38/38 [00:03<00:00, 9.55it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 17
ACT I. Scene I.
But wherefore be the tempest of the heart?
What says he?
MENENIUS. No, I c
--- EPOCH 18/50 ---
train_batch (Avg. Loss 1.188, Accuracy 63.5): 100%|██████████| 348/348 [01:31<00:00, 3.80it/s]
test_batch (Avg. Loss 1.248, Accuracy 60.6): 100%|██████████| 38/38 [00:03<00:00, 10.15it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 18
ACT I. Scene I
--- EPOCH 19/50 ---
train_batch (Avg. Loss 1.183, Accuracy 63.6): 100%|██████████| 348/348 [01:26<00:00, 4.00it/s]
test_batch (Avg. Loss 1.247, Accuracy 60.6): 100%|██████████| 38/38 [00:03<00:00, 10.07it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 19
ACT I.
Enter CORIOLANUS
CASSIO. My lord, th
--- EPOCH 20/50 ---
train_batch (Avg. Loss 1.178, Accuracy 63.7): 100%|██████████| 348/348 [01:27<00:00, 3.98it/s]
test_batch (Avg. Loss 1.241, Accuracy 60.8): 100%|██████████| 38/38 [00:03<00:00, 10.16it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 20
ACT I.
Exeunt
SCENE 2.
Britain.
Enter a
--- EPOCH 21/50 ---
train_batch (Avg. Loss 1.174, Accuracy 63.8): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.238, Accuracy 60.9): 100%|██████████| 38/38 [00:03<00:00, 10.13it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 21
ACT I.
Exeunt.
Scene III.
Prince. The tric
--- EPOCH 22/50 ---
train_batch (Avg. Loss 1.170, Accuracy 64.0): 100%|██████████| 348/348 [01:27<00:00, 3.98it/s]
test_batch (Avg. Loss 1.236, Accuracy 61.0): 100%|██████████| 38/38 [00:03<00:00, 10.10it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 22
ACT I. Say the cause,
Do you not love?
ANTONY. You have not so better.
If you will pay you
--- EPOCH 23/50 ---
train_batch (Avg. Loss 1.167, Accuracy 64.0): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.230, Accuracy 61.1): 100%|██████████| 38/38 [00:03<00:00, 9.55it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 23
ACT I.
Exeunt
SCENE II.
A street.
Enter
--- EPOCH 24/50 ---
train_batch (Avg. Loss 1.163, Accuracy 64.1): 100%|██████████| 348/348 [01:32<00:00, 3.75it/s]
test_batch (Avg. Loss 1.228, Accuracy 61.2): 100%|██████████| 38/38 [00:04<00:00, 9.32it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 24
ACT I.
[Takes him.
--- EPOCH 25/50 ---
train_batch (Avg. Loss 1.160, Accuracy 64.2): 100%|██████████| 348/348 [01:32<00:00, 3.74it/s]
test_batch (Avg. Loss 1.228, Accuracy 61.2): 100%|██████████| 38/38 [00:04<00:00, 9.46it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 25
ACT I. Speak to the court!
And therefore be a patience, and he shall go to him.
Here, and th
--- EPOCH 26/50 ---
train_batch (Avg. Loss 1.157, Accuracy 64.3): 100%|██████████| 348/348 [01:32<00:00, 3.75it/s]
test_batch (Avg. Loss 1.224, Accuracy 61.4): 100%|██████████| 38/38 [00:04<00:00, 9.37it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 26
ACT I. Some servants of the
heart of the Lady Mars.
CLOWN. The people is the better than the d
--- EPOCH 27/50 ---
train_batch (Avg. Loss 1.154, Accuracy 64.4): 100%|██████████| 348/348 [01:32<00:00, 3.76it/s]
test_batch (Avg. Loss 1.223, Accuracy 61.3): 100%|██████████| 38/38 [00:04<00:00, 9.45it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 27
ACT I. Sirnah, as I will
see you say.
CLOWN. What is the matter?
CLOWN. There is no matter t
--- EPOCH 28/50 ---
train_batch (Avg. Loss 1.152, Accuracy 64.5): 100%|██████████| 348/348 [01:33<00:00, 3.74it/s]
test_batch (Avg. Loss 1.220, Accuracy 61.5): 100%|██████████| 38/38 [00:03<00:00, 9.54it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 28
ACT I. Sir, I am sure you
will see your sister.
FALSTAFF. Well, do you see the world?
CLOWN.
--- EPOCH 29/50 ---
train_batch (Avg. Loss 1.149, Accuracy 64.5): 100%|██████████| 348/348 [01:30<00:00, 3.85it/s]
test_batch (Avg. Loss 1.216, Accuracy 61.5): 100%|██████████| 38/38 [00:03<00:00, 10.16it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 29
ACT I. So dear the rest of the
thousand of his body.
PORTIA. We will do what I would not have
--- EPOCH 30/50 ---
train_batch (Avg. Loss 1.146, Accuracy 64.6): 100%|██████████| 348/348 [01:28<00:00, 3.95it/s]
test_batch (Avg. Loss 1.218, Accuracy 61.4): 100%|██████████| 38/38 [00:03<00:00, 9.51it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 30
ACT I. Sir, a man of
my countryman.
LUCIO. Ay, and I will not fortune will not be a woman. A w
--- EPOCH 31/50 ---
train_batch (Avg. Loss 1.144, Accuracy 64.7): 100%|██████████| 348/348 [01:32<00:00, 3.77it/s]
test_batch (Avg. Loss 1.213, Accuracy 61.7): 100%|██████████| 38/38 [00:03<00:00, 9.59it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 31
ACT I. Sir John, if you will
with your heart, I would not, my lord.
LUCIUS. Why, hark you, goo
--- EPOCH 32/50 ---
train_batch (Avg. Loss 1.141, Accuracy 64.7): 100%|██████████| 348/348 [01:31<00:00, 3.82it/s]
test_batch (Avg. Loss 1.210, Accuracy 61.7): 100%|██████████| 38/38 [00:03<00:00, 9.68it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 32
ACT I. Sir,
What can I know you are not with your father?
You have a cursed and of war with
--- EPOCH 33/50 ---
train_batch (Avg. Loss 1.139, Accuracy 64.8): 100%|██████████| 348/348 [01:30<00:00, 3.83it/s]
test_batch (Avg. Loss 1.209, Accuracy 61.7): 100%|██████████| 38/38 [00:03<00:00, 9.71it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 33
ACT I. Sir John.
SEBASTIAN. Now, cousin, I will not be a contented to you.
If they may bear th
--- EPOCH 34/50 ---
train_batch (Avg. Loss 1.136, Accuracy 64.9): 100%|██████████| 348/348 [01:30<00:00, 3.83it/s]
test_batch (Avg. Loss 1.207, Accuracy 61.8): 100%|██████████| 38/38 [00:03<00:00, 9.58it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 34
ACT I.
Exeunt
SCENE III.
The stat
--- EPOCH 35/50 ---
train_batch (Avg. Loss 1.135, Accuracy 64.9): 100%|██████████| 348/348 [01:30<00:00, 3.84it/s]
test_batch (Avg. Loss 1.208, Accuracy 61.8): 100%|██████████| 38/38 [00:03<00:00, 9.72it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 35
ACT I.
Exeunt
SCENE III.
Before the
--- EPOCH 36/50 ---
train_batch (Avg. Loss 1.133, Accuracy 65.0): 100%|██████████| 348/348 [01:30<00:00, 3.84it/s]
test_batch (Avg. Loss 1.207, Accuracy 61.8): 100%|██████████| 38/38 [00:03<00:00, 9.60it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 36
ACT I. And I will shake thee
to say the brain of man should give him a time.
FALSTAFF. He was
--- EPOCH 37/50 ---
train_batch (Avg. Loss 1.131, Accuracy 65.0): 100%|██████████| 348/348 [01:30<00:00, 3.85it/s]
test_batch (Avg. Loss 1.204, Accuracy 61.9): 100%|██████████| 38/38 [00:03<00:00, 9.72it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 37
ACT I. A strange poor
friend.
CLOTEN. Then the world are the sweetest men to our pretty and bo
--- EPOCH 38/50 ---
train_batch (Avg. Loss 1.129, Accuracy 65.1): 100%|██████████| 348/348 [01:30<00:00, 3.84it/s]
test_batch (Avg. Loss 1.202, Accuracy 62.0): 100%|██████████| 38/38 [00:03<00:00, 9.71it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 38
ACT I. Speak.
OTHELLO. With the rest; I have no present present death,
And in the world have s
--- EPOCH 39/50 ---
train_batch (Avg. Loss 1.127, Accuracy 65.1): 100%|██████████| 348/348 [01:30<00:00, 3.85it/s]
test_batch (Avg. Loss 1.198, Accuracy 62.0): 100%|██████████| 38/38 [00:03<00:00, 9.71it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 39
ACT I. My lord,
What says your Grace to say you are so far?
The man hath made the party of y
--- EPOCH 40/50 ---
train_batch (Avg. Loss 1.126, Accuracy 65.1): 100%|██████████| 348/348 [01:30<00:00, 3.87it/s]
test_batch (Avg. Loss 1.199, Accuracy 62.1): 100%|██████████| 38/38 [00:03<00:00, 10.11it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 40
ACT I. Stand to him.
SIR TOBY. You are a letter from my heart.
MRS. PAGE. What can you see the
--- EPOCH 41/50 ---
train_batch (Avg. Loss 1.125, Accuracy 65.2): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.199, Accuracy 62.0): 100%|██████████| 38/38 [00:03<00:00, 10.14it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 41
ACT I.
Exeunt
SCENE III.
A stree
--- EPOCH 42/50 ---
train_batch (Avg. Loss 1.123, Accuracy 65.2): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.196, Accuracy 62.0): 100%|██████████| 38/38 [00:03<00:00, 10.12it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 42
ACT I. Say, sir, the
most reverend Coriolanus is a subject, as I have seen the
end of him. T
--- EPOCH 43/50 ---
train_batch (Avg. Loss 1.121, Accuracy 65.3): 100%|██████████| 348/348 [01:26<00:00, 4.00it/s]
test_batch (Avg. Loss 1.198, Accuracy 62.1): 100%|██████████| 38/38 [00:03<00:00, 9.81it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 43
ACT I. She came not for him.
And what is this? What do you think I live?
You must not stay a
--- EPOCH 44/50 ---
train_batch (Avg. Loss 1.119, Accuracy 65.3): 100%|██████████| 348/348 [01:27<00:00, 3.98it/s]
test_batch (Avg. Loss 1.193, Accuracy 62.2): 100%|██████████| 38/38 [00:03<00:00, 10.17it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 44
ACT I. Sweet Pandarus.
Exeunt
SCENE III.
The t
--- EPOCH 45/50 ---
train_batch (Avg. Loss 1.118, Accuracy 65.3): 100%|██████████| 348/348 [01:27<00:00, 3.99it/s]
test_batch (Avg. Loss 1.195, Accuracy 62.2): 100%|██████████| 38/38 [00:03<00:00, 10.12it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 45
ACT I. A poor wife.
Exit SERVANT
SERVANT.
--- EPOCH 46/50 ---
train_batch (Avg. Loss 1.116, Accuracy 65.4): 100%|██████████| 348/348 [01:27<00:00, 4.00it/s]
test_batch (Avg. Loss 1.193, Accuracy 62.2): 100%|██████████| 38/38 [00:03<00:00, 10.22it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 46
ACT I. A palace
I had not left the good time to the sea,
And show it you the devil have made
--- EPOCH 47/50 ---
train_batch (Avg. Loss 1.115, Accuracy 65.4): 100%|██████████| 348/348 [01:27<00:00, 3.99it/s]
test_batch (Avg. Loss 1.190, Accuracy 62.2): 100%|██████████| 38/38 [00:03<00:00, 10.10it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 47
ACT I.
Exit
MACBETH. I will give the
--- EPOCH 48/50 ---
train_batch (Avg. Loss 1.115, Accuracy 65.4): 100%|██████████| 348/348 [01:27<00:00, 3.99it/s]
test_batch (Avg. Loss 1.191, Accuracy 62.2): 100%|██████████| 38/38 [00:03<00:00, 10.15it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 48
ACT I. Prithee, get thee gone.
CASSIUS. I will not see this consul.
CASSIUS. I will contend you
--- EPOCH 49/50 ---
train_batch (Avg. Loss 1.113, Accuracy 65.5): 100%|██████████| 348/348 [01:27<00:00, 3.99it/s]
test_batch (Avg. Loss 1.192, Accuracy 62.3): 100%|██████████| 38/38 [00:03<00:00, 10.10it/s]
Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.
The text you generate should “look” like a Shakespeare play: old-style English words and sentence structure, stage directions for the actors (like “Exit/Enter”), sections (Act I/Scene III), etc. There will be no coherent plot, of course, but it should at least seem like a Shakespearean play when not looking too closely. If this is not what you see, go back, debug, and/or re-train.
TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.
from hw3.answers import part1_generation_params
start_seq, temperature = part1_generation_params()
generated_sequence = charnn.generate_from_model(
model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)
print(generated_sequence)
SCENE I.
Exeunt
SCENE II.
The forest
Enter the DUKE OF YORK, and the DUKE OF YORK, and the LORDS
CATESBY. Here comes your lordship.
CASSIO. Exit
SCENE XIII.
Alexandria.
Enter Castle.
CAESAR. What is the man?
MESSENGER. The King is such a thing that I have heard
The world is come to see the sense of them.
The seas and honours of the hour of men
Have seen the princes of the common prince.
I know the devil that I have seen thee not.
And therefore have the word that thou wilt weep
To see thy soul to thee thy father's death.
Then stand by thy son to thy hope, thou dost.
If thou dost speak, thou art a serpent,
That I have seen thy son to this most sweet
That had not struck me with the life of heaven.
If thou hast spoke the strength of this her death,
The sun that loves me to the world that stands,
And then the substance of the seas do look upon
The common souls of heart and honour being.
And then I see thee for thy soul thy face,
And then thou shalt not speak to me again.
Then thou art low and seek to stay the state
That I have been a strong advantage that
The seas and things of this and the devil sound
Shall be a witch to show the world in heaven.
The stocks of Rome are with the state of death,
And then the meanest sport and the remainder
That stands upon the sea that seems to find
The same of his delight. What is the matter?
The court of such a company they stand,
And the subjection of the sea shall bear
The world of hearts, and the received of men
With such a shame and false offenders do.
There is no more than the precore of the world.
The greater than the sweet stand by the world,
And the rest of the world the common spirit
Which he hath been a stranger to the course.
What would you be a woman?
CORIOLANUS. I will not say 'Tis married.'
I know not what to see your Grace to hear
The prince of Barnardine.
CLEOPATRA. He hath been sorry
To change the common poor desires of me.
I shall not suffer them. I have the world
To be the strength of this and the design.
I have seen them all the world the first
And the accomplish of the seas and thee
As thou art fairer than they should discharge.
What shall I say the Duke of Lancaster?
What say you, madam? What a fearful soul
Is this the stars of men and love of love
And bring the world the shame of her desires?
What is the matter?
ARCHBISHOP OF CANTERIUS. I will not say the lords and them are strange.
The more the devil and the men of God,
And the rest from the world will send the world.
What should I say that I may say in heaven?
I would not be a shepherd to the world,
And then the state of this desire is sure
That I have seen thee for the counterfeit.
The sea and man of heaven and the word that should be
That will not see the strength of the deed of her
That was the sun to see them all to hear.
The King hath stand'd his heart and her hand here,
And strikes his father for his honour'd hands,
And then the song of them and his desires
May be the shame of his desires and thoughts.
And so I do not see the lark with me.
Exit
SCENE IV.
A man shall be search. The Duke of Caesar
Enter the KING, and the DUKE OF YORK
KING HENRY. What says my lord?
CASSIUS. I thank you, sir.
CASSIO. I have a thing to speak with you.
Exeunt
SCENE III.
A street. Enter KING HENRY, CAITARUS, and others
MENENIUS. What shall I see your honour to the King?
What are you there?
CLOTEN. I thank you, sir.
CLOWN. What is the matter?
PISTOL. I will see you well.
CLOTEN. I will not speak with you.
DUKE. I will not have them the son of the world.
CRESSIDA. I do not speak a word with you. I have said to the
lords and the world were not to see him with his wits.
CLOWN. I am glad to see the prince and the most promise of the world and
the world will not be so poor a man as the forest of the house
of the plain song.
PAGE. Why, that's a true player. I will be the commonwealth
of the proverb of the world. I have a such a sense as they
are men for the state of the world, and the man is too
to the soldier.
PRINCE JOHN. What a man is the man?
COUNTESS. Why, he hath been a servant to the King, and the sea
and the word of the world.
CLOWN. I think he was a great man that should be the most strange
country that he hath not the sun that we have so much
for him. I will see a fool that I would have her man and
the prince of the more that he had not here to see the
country of his head.
FALSTAFF. What a peace I will think it not?
FALSTAFF. What a summer stander is the man in his head?
CLOWN. I will deliver him to the complexion of the lord.
CLOWN. I will not see her soul to the presence. I will be not
to see your honour.
FALSTAFF. What say you, man? I will see the counterfeit man to
the proverbs of the party of the commons that have been a
strange and a secret word.
FALSTAFF. I am a very worth of the maid, and the law is not
to be a man.
FALSTAFF. I will tell you a woman's son to the people.
FALSTAFF. What shall I do me the court?
MRS. FORD. I will stand to man and the stranger and the soldier
of his house and his beautees and the services of
the heavens and the heart of the wars. I would I were a good
soldier, and there is a good world to see the beggar of a fair and
three hours of his hands.
CLOWN. I have seen her that have been a man that hath been a
father that hath a man to see him a beard of his beard.
He was a man that hath been a man that have been a son of
him, and he shall see her beauty.
CLOWN. Why, then the while is not a strange fool.
FALSTAFF. I do not know thee not.
SIR TOBY. Why, then I say the word is a more than a man.
I am sure they were not the matter. I am a good man and the
conceit of the sun of his particular company and the pretty
fool that he will be black and so far of a seat of mine.
Exit
BASTARD. What a man should be the charge of this?
Where is the house? What say you?
CORIOLANUS. I shall not see them stay.
I have seen the common state of my heart
That I do love thee for the grave and friends.
What say you to the common man of this?
I would not be a death of mine as you
Have been a stranger to the holy sing.
The people will be stronger than the mark
Of the rest do the world and honour of it.
The sea will be the sun to see his hand.
What say you to the prince of England's heart?
And you the love of England's son shall stay
The sun to see the princes of the world,
And then the shadow of the sea was found,
And there is no more fair and beggary
Than the most strange and storm of man and heart
And with the stream of stars and hearts and treason.
And then the senators of the sea doth struck
The last of them and the world's state of heaven,
And the contempt of my dear holy face
Will be the suppress of the courtesy.
The lord of Cainal Parther, the King his hand,
And there and the King's son, and the Duke of York.
What says the man? What say you to the world?
Why do you see the state of this most sour?
I do not know your brother to your honour
To make a stranger to the court of heaven,
And therefore stand upon the seas and states,
And there the seas and tongues of heaven did stand.
And therefore shall you be the stranger heart,
And then the parley of the sea is done.
The sea to the most prince and the more things
That he deserves the seas of the design.
The sun shall be the greater than the chamber.
Exeunt
SCENE II.
A street.
Enter Capulets, and Cart and Claudio.
Fran. The sun shall be a state of heart and heart,
And she hath been the strength of this account.
And therefore shall the morning call him brother,
And then the seat of England and the sea
Of the sun should have been a seat of blood,
And he that shall be strangers of the state
Of the stranger of the grave of the world,
And the continion of the streets of France
And honour him that stands his counterfeit
To the device and the sun of the world,
And then the sea of death is for a shame.
There is no more than they do see the heart.
Then let me see the soldier of thy head,
And thou deceiv'd thy soul to the dear sister.
The sea was made a great and stronger than thee
That thou hast standed to thy heart and thee.
And thou thy father, I will see thy hands.
And thou hast spoken thee to me again,
And then thou art the seas and the delights.
The seas and strokes of service and the state
That makes the stocks of men, and the world said
My father and his son will be at hand.
But when I have the man of this desire,
The state of many thousand of the world,
And the deserving of the world was so,
And there the world is come to see the world.
The better that the sun that looks at him,
And the best streams of heaven was made a dead,
And the best service of the state of heaven
Is the ambition of the world of heaven.
The sea was here and here a man more strong
Than the remainder of the world is serv'd.
I am so right a soul of mine own s
TODO: Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.
from cs236781.answers import display_answer
import hw3.answers
Why do we split the corpus into sequences instead of training on the whole text?
display_answer(hw3.answers.part1_q1)
We split the corpus into sequences for the following reasons:
How is it possible that the generated text clearly shows memory longer than the sequence length?
display_answer(hw3.answers.part1_q2)
The generated text shows memory longer than the sequence length because the output at every step depends not only on the current input but also on the hidden state, which is the result of all previous steps; the hidden state thus acts as the network's memory.
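As a toy illustration of this point (a NumPy sketch, not the assignment's GRU), the following carries a scalar hidden state across two consecutive sequence chunks without resetting it between them; a trace of the first chunk's inputs survives into the second, which is exactly how memory longer than the sequence length arises:

```python
import numpy as np

def rnn_step(h, x):
    # Toy recurrence: the new hidden state mixes the old one with the input
    return np.tanh(0.5 * h + 0.5 * x)

chunks = [np.ones(3), np.zeros(3)]  # two consecutive sequence chunks
h = 0.0
for chunk in chunks:  # h is NOT reset between chunks...
    for x in chunk:
        h = rnn_step(h, x)
# ...so even after processing the all-zeros chunk, h > 0:
# it still carries a trace of the earlier all-ones chunk
```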
Why are we not shuffling the order of batches when training?
display_answer(hw3.answers.part1_q3)
We do not shuffle the order of batches when training because, unlike the previous models we studied, where the order of the samples was unimportant, the RNN must learn each new character from the preceding ones and capture the logical connections between them in order to learn the correct patterns. Shuffling the batch order would break this continuity, interfere with the training process, and result in learning the wrong patterns.
display_answer(hw3.answers.part1_q4)
The temperature, as explained in the exercise, is a hyperparameter that controls the variance of the distribution of the next character, conditioned on the current one and on the current state of the model. A low temperature value leads to lower variance, maximizing the chance of choosing the most likely character; in other words, it pushes the model toward reproducing the patterns of the training data.
When sampling, we would like to choose a lower temperature since at that stage, assuming our model is well trained, we want the model to generate the output it considers the best fit.
Choosing higher values of temperature leads to higher variance of the next-character distribution, which means a higher chance of choosing unlikely characters, resulting in more made-up words and more spelling mistakes. The model will be more "creative" in a way.
Choosing lower values of temperature leads to lower variance of the next-character distribution, which means more confidence in the chosen characters; this results in better spelling but less diversity in the generated text. The model will be more "conservative" in a way.
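To make the temperature's effect concrete, here is a minimal NumPy sketch (illustrative, not the assignment's implementation): dividing the scores by $T$ before the softmax sharpens the distribution for low $T$ and flattens it toward uniform for high $T$.

```python
import numpy as np

def softmax_with_temperature(scores, T):
    """Scale scores by 1/T before softmax: low T sharpens, high T flattens."""
    scaled = np.asarray(scores, dtype=float) / T
    scaled -= scaled.max()  # subtract max for numerical stability
    e = np.exp(scaled)
    return e / e.sum()

scores = [2.0, 1.0, 0.1]
hot = softmax_with_temperature(scores, T=5.0)   # close to uniform
cold = softmax_with_temperature(scores, T=0.1)  # close to one-hot
```

With `T=0.1` nearly all the probability mass concentrates on the highest-scoring character, while with `T=5.0` the three probabilities are almost equal.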
In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile
import numpy as np
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda
Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset, which contains many labeled faces of famous individuals.
We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)
However, if you feel adventurous and/or prefer to generate something else, feel free to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.
import cs236781.plot as plot
import cs236781.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
DATA_URL = CUSTOM_DATA_URL
_, dataset_dir = cs236781.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/tal.r/.pytorch-datasets/lfw-bush.zip exists, skipping download. Extracting /home/tal.r/.pytorch-datasets/lfw-bush.zip... Extracted 531 to /home/tal.r/.pytorch-datasets/lfw/George_W_Bush
Create a Dataset object that will load the extracted images:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
im_size = 64
tf = T.Compose([
# Resize to constant spatial dimensions
T.Resize((im_size, im_size)),
# PIL.Image -> torch.Tensor
T.ToTensor(),
# Dynamic range [0,1] -> [-1, 1]
T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])
ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)
OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(15,10), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)
test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])
An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e. without any labels). Recall its general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a model with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a model with parameters $\bb{\beta}$).
While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.
We define, in Bayesian terminology,
To create our variational decoder we'll further specify:
This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.
Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) = \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder model, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is to use what's known as the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma}_{\bb{\alpha}}(\bb{x})$.
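A minimal NumPy sketch of the reparametrization trick described above (the values of `mu` and `sigma` are illustrative, standing in for the encoder's outputs): the noise `u` is sampled independently of the model parameters, so `z` remains a differentiable function of the mean and standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])    # stand-in for the encoder's posterior mean
sigma = np.array([0.1, 0.2])  # stand-in for the encoder's posterior std

# Sample noise from an isotropic Gaussian (no trainable parameters involved)
u = rng.standard_normal(size=(10000, 2))
# Deterministic, differentiable transform of mu and sigma
z = mu + u * sigma
```

Empirically, `z` has mean `mu` and standard deviation `sigma` per dimension, i.e. it is distributed as the desired posterior, yet gradients can flow through `mu` and `sigma`.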
To train a VAE model, we maximize the evidence, $p(\bb{X})$ (see question below). The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{X})$. Although this expectation is intractable, we can obtain a lower-bound for $p(\bb{X})$ (the evidence lower bound, "ELBO", shown in the lecture):
$$ \log p(\bb{X}) \ge \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }\left[ \log p _{\bb{\beta}}(\bb{X} | \bb{z}) \right] - \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{X})\,\left\|\, p(\bb{Z} )\right.\right) $$where $ \mathcal{D} _{\mathrm{KL}}(q\left\|\right.p) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{Z})}{p(\bb{Z})} \right] $ is the Kullback-Leibler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z|X})$ instead of the prior distribution $p(\bb{Z})$.
Using the ELBO, the VAE loss becomes, $$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }\left[ -\log p _{\bb{\beta}}(\bb{x} | \bb{z}) \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$
By remembering that the likelihood is a Gaussian distribution with a diagonal covariance and by applying the reparametrization trick, we can write the above as
$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} } \left[ \frac{1}{2\sigma^2}\left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$

Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).
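For a diagonal-Gaussian posterior and a standard-normal prior, the KL term above has a well-known closed form, $\mathcal{D}_{\mathrm{KL}} = \frac{1}{2}\sum_j \left(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2\right)$. The following NumPy sketch (illustrative, not the assignment's implementation; `sigma2_lik` denotes the scalar instance-space variance $\sigma^2$) computes both loss terms for a single instance:

```python
import numpy as np

def vae_loss_terms(x, x_rec, mu, log_sigma2, sigma2_lik):
    # Reconstruction term: negative Gaussian log-likelihood up to a constant,
    # with scalar instance-space variance sigma2_lik
    recon = np.sum((x - x_rec) ** 2) / (2 * sigma2_lik)
    # KL(q || N(0, I)) in closed form for a diagonal-Gaussian posterior
    kl = 0.5 * np.sum(np.exp(log_sigma2) + mu ** 2 - 1 - log_sigma2)
    return recon, kl

# Sanity check: when the posterior equals the prior (mu=0, sigma^2=1)
# and the reconstruction is perfect, both terms vanish
mu, log_sigma2 = np.zeros(4), np.zeros(4)
recon, kl = vae_loss_terms(np.ones(4), np.ones(4), mu, log_sigma2, sigma2_lik=1.0)
```

Parameterizing the posterior variance through its log (`log_sigma2`) is a common choice that keeps the variance positive without constraints.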
First let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map it to an activation volume of a specified depth. We'll consider this volume as the features we extract from the input image. Later we'll use these to create the latent space representation of the input.
import hw3.autoencoder as autoencoder
in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)
h = encoder_cnn(x0)
print(h.shape)
test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
(cnn): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): ReLU()
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(4): ReLU()
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(7): ReLU()
(8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): Conv2d(256, 1024, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(10): ReLU()
(11): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
torch.Size([1, 1024, 4, 4])
Now let's implement the CNN part of the Decoder.
Again this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced
by your EncoderCNN and output an image with the same dimensions as the Encoder's input.
This can be a CNN which is like a "mirror image" of the Encoder: for example, replace convolutions with transposed convolutions, downsampling with upsampling, etc.
Consult the documentation of ConvTranspose2d
to figure out how to reverse your convolutional layers in terms of input and output dimensions. Note that the decoder doesn't have to be exactly the opposite of the encoder and you can experiment with using a different architecture.
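For instance, with the kernel size, stride, and padding used in the networks above, setting output_padding=1 makes a transposed convolution exactly undo the stride-2 downsampling of the matching convolution. A quick standalone shape check (toy channel counts, purely illustrative):

```python
import torch
import torch.nn as nn

# Conv2d output size: floor((H + 2*pad - k) / stride) + 1 = H/2 for even H
conv = nn.Conv2d(3, 8, kernel_size=5, stride=2, padding=2)
# ConvTranspose2d output size: (H_in - 1)*stride - 2*pad + k + output_padding
# = (H_in - 1)*2 - 4 + 5 + 1 = 2*H_in
tconv = nn.ConvTranspose2d(8, 3, kernel_size=5, stride=2,
                           padding=2, output_padding=1)

x = torch.randn(1, 3, 64, 64)
h = conv(x)     # downsampled activation volume
xr = tconv(h)   # back to the original spatial size
print(h.shape, xr.shape)  # torch.Size([1, 8, 32, 32]) torch.Size([1, 3, 64, 64])
```

Without output_padding the transposed convolution would produce a 63x63 output, since several input sizes map to the same downsampled size; output_padding disambiguates which one to invert to.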
TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)
test.assertEqual(x0.shape, x0r.shape)
# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
(cnn): Sequential(
(0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU()
(2): ConvTranspose2d(1024, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU()
(5): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU()
(8): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU()
(11): ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
)
)
torch.Size([1, 3, 64, 64])
Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:
$$
\begin{align}
\bb{\mu} _{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\
\log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}}
\end{align}
$$
where $\vec{h}$ is the flattened feature volume produced by the EncoderCNN.

Notice that we model the log of the variance, not the actual variance. The above formulation is proposed in appendix C of the VAE paper.
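A small illustration of why the log-variance parametrization is convenient: the linear layer's output can be any real number, yet the implied variance is always strictly positive, and the standard deviation needed for sampling is recovered with a single exp:

```python
import torch

# The network regresses log(sigma^2), which is unconstrained...
log_sigma2 = torch.tensor([-10.0, 0.0, 10.0])  # any real value is valid
# ...while the implied variance and std dev are guaranteed positive:
sigma2 = torch.exp(log_sigma2)       # variance, always > 0
sigma = torch.exp(0.5 * log_sigma2)  # std dev, used in the reparametrization
print(sigma2, sigma)
```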
TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module.
You'll also need to define your parameters in __init__().
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)
z, mu, log_sigma2 = vae.encode(x0)
test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)
print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')
VAE(
(features_encoder): EncoderCNN(
(cnn): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): ReLU()
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(4): ReLU()
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(7): ReLU()
(8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): Conv2d(256, 1024, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(10): ReLU()
(11): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(features_decoder): DecoderCNN(
(cnn): Sequential(
(0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU()
(2): ConvTranspose2d(1024, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU()
(5): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU()
(8): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU()
(11): ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
)
)
(mu): Linear(in_features=16384, out_features=2, bias=True)
(log_sigma2): Linear(in_features=16384, out_features=2, bias=True)
(rec): Linear(in_features=2, out_features=16384, bias=True)
)
mu(x0)=[0.34623602, -0.028301135], sigma2(x0)=[1.4154439, 1.3355203]
Let's sample some 2d latent representations for an input image x0 and visualize them.
# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
for i in range(N):
Z[i], _, _ = vae.encode(x0)
ax.scatter(*Z[i].cpu().numpy())
# Should be close to the mu/sigma in the previous block above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
sampled mu tensor([ 0.2664, -0.1077])
sampled sigma2 tensor([2.0443, 1.8637])
Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It will work as follows: an affine layer maps the latent $\bb{z}$ back to the feature dimension, the result is reshaped into an activation volume, and that volume is passed through the DecoderCNN to reconstruct an image.
TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module.
You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.
x0r = vae.decode(z)
test.assertSequenceEqual(x0r.shape, x0.shape)
Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.
x0r, mu, log_sigma2 = vae(x0)
test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:
$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2 d_x} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}), $$where $d_z$ is the dimension of the latent space, $d_x$ is the dimension of the input and $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$. This pointwise loss is the quantity that we'll compute and minimize with gradient descent. The first term corresponds to the data-reconstruction loss, while the second term corresponds to the KL-divergence loss. Note that the scaling by $d_x$ is not derived from the original loss formula and was added directly to the pointwise loss just to normalize the data term.
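Before implementing, it may help to see the formula spelled out as code. The following is a sketch of one way to compute the pointwise loss over a batch (the function name and exact normalization choices are our own; your graded vae_loss may normalize differently, so don't expect it to reproduce the expected test value below verbatim):

```python
import torch

def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    """Sketch of the pointwise VAE loss: data term + KL term.
    Returns the total loss and the separate data/KL terms."""
    N = x.shape[0]
    dx = x[0].numel()    # d_x: dimension of one input
    dz = z_mu.shape[1]   # d_z: dimension of the latent space

    # Data term: squared reconstruction error, scaled by 1/(sigma^2 * d_x)
    data_loss = ((x - xr) ** 2).reshape(N, -1).sum(dim=1) / (x_sigma2 * dx)

    # KL term, closed form for a diagonal-Gaussian posterior vs. N(0, I):
    # tr(Sigma) + ||mu||^2 - d_z - log det(Sigma)
    kldiv_loss = (torch.exp(z_log_sigma2).sum(dim=1)
                  + z_mu.pow(2).sum(dim=1) - dz - z_log_sigma2.sum(dim=1))

    # Mean over the batch replaces the expectation over x
    return (data_loss + kldiv_loss).mean(), data_loss.mean(), kldiv_loss.mean()
```

Note that with a perfect reconstruction (xr equal to x) and a posterior equal to the prior (zero mean, zero log-variance), both terms vanish, which is a useful sanity check for any implementation.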
TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.
from hw3.autoencoder import vae_loss
torch.manual_seed(42)
def test_vae_loss():
# Test data
N, C, H, W = 10, 3, 64, 64
z_dim = 32
x = torch.randn(N, C, H, W)*2 - 1
xr = torch.randn(N, C, H, W)*2 - 1
z_mu = torch.randn(N, z_dim)
z_log_sigma2 = torch.randn(N, z_dim)
x_sigma2 = 0.9
loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
test.assertAlmostEqual(loss.item(), 58.3234367, delta=1e-3)
return loss
test_vae_loss()
tensor(58.3234)
The main advantage of a VAE is that it can be used as a generative model by sampling from the latent space, since we optimize for an isotropic Gaussian prior $p(\bb{Z})$ in the loss function. Let's now implement this so that we can visualize how our model is doing when we train.
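Conceptually, sampling only requires drawing $\bb{z}$ from the standard-normal prior and decoding it. Here's a standalone toy sketch with a stand-in linear decoder (hypothetical; in the real model, $\bb{z}$ goes through the z-to-features layer and the DecoderCNN):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, n = 8, 5
# Hypothetical stand-in for the trained decoder: maps z to a 3x64x64 image
toy_decoder = nn.Sequential(
    nn.Linear(z_dim, 3 * 64 * 64),
    nn.Unflatten(1, (3, 64, 64)),
)

with torch.no_grad():               # no gradients needed for sampling
    z = torch.randn(n, z_dim)       # z ~ p(Z) = N(0, I)
    samples = toy_decoder(z)        # decode into image space
print(samples.shape)  # torch.Size([5, 3, 64, 64])
```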
TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)
Time to train!
TODO:
Implement the VAETrainer class in the hw3/training.py module. Make sure to implement the checkpoints feature of the Trainer class if you haven't done so already in Part 1. Then set your hyperparameters in the part2_vae_hyperparams() function within the hw3/answers.py module.

import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams
torch.manual_seed(42)
# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']
# Data
split_lengths = [len(ds_gwb) - len(ds_gwb)//10, len(ds_gwb)//10]  # 90/10 split; lengths always sum to len(ds_gwb)
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test = DataLoader(ds_test, batch_size, shuffle=True)
im_size = ds_train[0][0].shape
# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)
# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)
# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
os.remove(f'{checkpoint_file}.pt')
# Show model and hypers
print(vae)
print(hp)
VAE(
(features_encoder): EncoderCNN(
(cnn): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): ReLU()
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(4): ReLU()
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(7): ReLU()
(8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(10): ReLU()
(11): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(features_decoder): DecoderCNN(
(cnn): Sequential(
(0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU()
(2): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU()
(5): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU()
(8): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU()
(11): ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
)
)
(mu): Linear(in_features=8192, out_features=256, bias=True)
(log_sigma2): Linear(in_features=8192, out_features=256, bias=True)
(rec): Linear(in_features=256, out_features=8192, bias=True)
)
{'batch_size': 8, 'h_dim': 512, 'z_dim': 256, 'x_sigma2': 0.001, 'learn_rate': 0.0001, 'betas': (0.9, 0.99)}
TODO: Train the model by running the block below. When you're satisfied with the results, save a copy of your best checkpoint file with the suffix _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training. Note that your final submission zip will not include the checkpoints/ folder. This is OK.

The images you get should be colorful, with different backgrounds and poses.
import IPython.display
def post_epoch_fn(epoch, train_result, test_result, verbose):
# Plot some samples if this is a verbose epoch
if verbose:
samples = vae.sample(n=5)
fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
IPython.display.display(fig)
plt.close(fig)
if os.path.isfile(f'{checkpoint_file_final}.pt'):
print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
checkpoint_file = checkpoint_file_final
else:
res = trainer.fit(dl_train, dl_test,
num_epochs=200, early_stopping=20, print_every=10,
checkpoints=checkpoint_file,
post_epoch_fn=post_epoch_fn)
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
--- EPOCH 1/200 ---
train_batch (Avg. Loss 378.441, Accuracy 0.0): 100%|██████████| 60/60 [00:04<00:00, 12.89it/s]
test_batch (Avg. Loss 310.163, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 22.61it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 1
[... training log truncated: subsequent epochs repeat the same pattern, with a checkpoint saved after every epoch; by epoch 89 the average train loss has decreased from ~378 to ~134 and the test loss from ~310 to ~154 ...]
Loss 133.772, Accuracy 0.1): 100%|██████████| 60/60 [00:04<00:00, 12.08it/s] test_batch (Avg. Loss 153.504, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.84it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 90 --- EPOCH 91/200 --- train_batch (Avg. Loss 134.791, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.41it/s] test_batch (Avg. Loss 152.286, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.28it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 91
train_batch (Avg. Loss 133.561, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.65it/s] test_batch (Avg. Loss 149.194, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.64it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 92 train_batch (Avg. Loss 134.101, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.38it/s] test_batch (Avg. Loss 148.030, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.85it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 93 train_batch (Avg. Loss 132.928, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.61it/s] test_batch (Avg. Loss 154.436, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.03it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 94 train_batch (Avg. Loss 132.634, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.38it/s] test_batch (Avg. Loss 151.332, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.76it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 95 train_batch (Avg. Loss 133.423, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.93it/s] test_batch (Avg. Loss 153.596, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.91it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 96 train_batch (Avg. Loss 132.371, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.65it/s] test_batch (Avg. Loss 155.632, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.09it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 97 train_batch (Avg. Loss 132.912, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.66it/s] test_batch (Avg. Loss 144.603, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.82it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 98 train_batch (Avg. Loss 131.589, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.77it/s] test_batch (Avg. Loss 151.581, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.06it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 99 train_batch (Avg. 
Loss 131.059, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.89it/s] test_batch (Avg. Loss 149.586, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.09it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 100 --- EPOCH 101/200 --- train_batch (Avg. Loss 130.511, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.02it/s] test_batch (Avg. Loss 163.039, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.85it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 101
train_batch (Avg. Loss 130.275, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.30it/s] test_batch (Avg. Loss 149.420, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.42it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 102 train_batch (Avg. Loss 128.844, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.09it/s] test_batch (Avg. Loss 146.956, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.57it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 103 train_batch (Avg. Loss 128.522, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.77it/s] test_batch (Avg. Loss 146.297, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.10it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 104 train_batch (Avg. Loss 129.928, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.71it/s] test_batch (Avg. Loss 149.851, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.35it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 105 train_batch (Avg. Loss 129.420, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.86it/s] test_batch (Avg. Loss 146.350, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.65it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 106 train_batch (Avg. Loss 130.030, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.93it/s] test_batch (Avg. Loss 150.771, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.12it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 107 train_batch (Avg. Loss 128.850, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.19it/s] test_batch (Avg. Loss 152.864, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.84it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 108 train_batch (Avg. Loss 130.930, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.86it/s] test_batch (Avg. Loss 178.068, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.88it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 109 train_batch (Avg. 
Loss 130.072, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.76it/s] test_batch (Avg. Loss 145.862, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.32it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 110 --- EPOCH 111/200 --- train_batch (Avg. Loss 127.675, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.84it/s] test_batch (Avg. Loss 152.146, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.19it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 111
train_batch (Avg. Loss 128.094, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.73it/s] test_batch (Avg. Loss 145.312, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.10it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 112 train_batch (Avg. Loss 126.716, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.43it/s] test_batch (Avg. Loss 147.944, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.98it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 113 train_batch (Avg. Loss 126.773, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.15it/s] test_batch (Avg. Loss 146.830, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.07it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 114 train_batch (Avg. Loss 125.708, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.04it/s] test_batch (Avg. Loss 144.773, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.40it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 115 train_batch (Avg. Loss 127.472, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.52it/s] test_batch (Avg. Loss 157.419, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.70it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 116 train_batch (Avg. Loss 126.670, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.45it/s] test_batch (Avg. Loss 148.224, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 21.46it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 117 train_batch (Avg. Loss 125.921, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.69it/s] test_batch (Avg. Loss 146.472, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 22.11it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 118 train_batch (Avg. Loss 124.524, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.83it/s] test_batch (Avg. Loss 152.780, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.54it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 119 train_batch (Avg. 
Loss 125.569, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.98it/s] test_batch (Avg. Loss 146.594, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.60it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 120 --- EPOCH 121/200 --- train_batch (Avg. Loss 129.011, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.56it/s] test_batch (Avg. Loss 149.101, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.62it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 121
train_batch (Avg. Loss 126.026, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.29it/s] test_batch (Avg. Loss 154.550, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.26it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 122 train_batch (Avg. Loss 123.601, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.01it/s] test_batch (Avg. Loss 149.289, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.90it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 123 train_batch (Avg. Loss 124.364, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.79it/s] test_batch (Avg. Loss 146.889, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.34it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 124 train_batch (Avg. Loss 123.667, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.75it/s] test_batch (Avg. Loss 166.806, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 21.12it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 125 train_batch (Avg. Loss 124.010, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.87it/s] test_batch (Avg. Loss 151.296, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.32it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 126 train_batch (Avg. Loss 124.104, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.59it/s] test_batch (Avg. Loss 148.395, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 18.90it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 127 train_batch (Avg. Loss 124.468, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.76it/s] test_batch (Avg. Loss 147.676, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.25it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 128 train_batch (Avg. Loss 123.256, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.83it/s] test_batch (Avg. Loss 143.264, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.35it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 129 train_batch (Avg. 
Loss 123.297, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.85it/s] test_batch (Avg. Loss 142.164, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.06it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 130 --- EPOCH 131/200 --- train_batch (Avg. Loss 124.784, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.22it/s] test_batch (Avg. Loss 144.756, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.53it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 131
train_batch (Avg. Loss 122.886, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.84it/s] test_batch (Avg. Loss 144.050, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.77it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 132 train_batch (Avg. Loss 126.254, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.93it/s] test_batch (Avg. Loss 169.624, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.21it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 133 train_batch (Avg. Loss 123.230, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.00it/s] test_batch (Avg. Loss 149.544, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 18.13it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 134 train_batch (Avg. Loss 121.969, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.38it/s] test_batch (Avg. Loss 156.372, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.43it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 135 train_batch (Avg. Loss 122.117, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.00it/s] test_batch (Avg. Loss 145.229, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.79it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 136 train_batch (Avg. Loss 120.728, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.48it/s] test_batch (Avg. Loss 143.013, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 23.75it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 137 train_batch (Avg. Loss 121.478, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.02it/s] test_batch (Avg. Loss 143.128, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.70it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 138 train_batch (Avg. Loss 121.318, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.53it/s] test_batch (Avg. Loss 148.999, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 12.88it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 139 train_batch (Avg. 
Loss 120.953, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.91it/s] test_batch (Avg. Loss 151.275, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.79it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 140 --- EPOCH 141/200 --- train_batch (Avg. Loss 122.634, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.16it/s] test_batch (Avg. Loss 160.188, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.30it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 141
train_batch (Avg. Loss 121.600, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.47it/s] test_batch (Avg. Loss 156.122, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.74it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 142 train_batch (Avg. Loss 121.341, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.71it/s] test_batch (Avg. Loss 165.139, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.03it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 143 train_batch (Avg. Loss 119.686, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.97it/s] test_batch (Avg. Loss 143.390, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.94it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 144 train_batch (Avg. Loss 120.514, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.05it/s] test_batch (Avg. Loss 177.553, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.30it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 145 train_batch (Avg. Loss 119.686, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.94it/s] test_batch (Avg. Loss 145.143, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.94it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 146 train_batch (Avg. Loss 119.869, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.45it/s] test_batch (Avg. Loss 143.812, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.73it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 147 train_batch (Avg. Loss 120.544, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.15it/s] test_batch (Avg. Loss 141.932, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.31it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 148 train_batch (Avg. Loss 119.515, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.05it/s] test_batch (Avg. Loss 136.384, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.48it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 149 train_batch (Avg. 
Loss 118.933, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.02it/s] test_batch (Avg. Loss 142.764, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.48it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 150 --- EPOCH 151/200 --- train_batch (Avg. Loss 119.134, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.35it/s] test_batch (Avg. Loss 147.251, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.37it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 151
train_batch (Avg. Loss 120.307, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.75it/s] test_batch (Avg. Loss 143.631, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.17it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 152 train_batch (Avg. Loss 120.768, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.24it/s] test_batch (Avg. Loss 143.780, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.28it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 153 train_batch (Avg. Loss 118.577, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.62it/s] test_batch (Avg. Loss 141.421, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.08it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 154 train_batch (Avg. Loss 117.720, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.37it/s] test_batch (Avg. Loss 143.787, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.05it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 155 train_batch (Avg. Loss 117.729, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.78it/s] test_batch (Avg. Loss 137.582, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 26.18it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 156 train_batch (Avg. Loss 118.235, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.91it/s] test_batch (Avg. Loss 140.008, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.23it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 157 train_batch (Avg. Loss 117.314, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.28it/s] test_batch (Avg. Loss 138.814, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.14it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 158 train_batch (Avg. Loss 119.040, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.61it/s] test_batch (Avg. Loss 150.831, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.21it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 159 train_batch (Avg. 
Loss 118.210, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.99it/s] test_batch (Avg. Loss 141.757, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.20it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 160 --- EPOCH 161/200 --- train_batch (Avg. Loss 117.939, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.19it/s] test_batch (Avg. Loss 143.061, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 21.42it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 161
train_batch (Avg. Loss 115.668, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.29it/s] test_batch (Avg. Loss 140.826, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.86it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 162 train_batch (Avg. Loss 117.373, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.63it/s] test_batch (Avg. Loss 147.717, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.49it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 163 train_batch (Avg. Loss 116.348, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.42it/s] test_batch (Avg. Loss 141.208, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 21.71it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 164 train_batch (Avg. Loss 116.428, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.03it/s] test_batch (Avg. Loss 145.861, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.74it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 165 train_batch (Avg. Loss 116.582, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.05it/s] test_batch (Avg. Loss 136.921, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 18.55it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 166 train_batch (Avg. Loss 117.604, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.77it/s] test_batch (Avg. Loss 140.295, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.30it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 167 train_batch (Avg. Loss 115.003, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.24it/s] test_batch (Avg. Loss 144.077, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.39it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 168 train_batch (Avg. Loss 116.722, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.44it/s] test_batch (Avg. Loss 142.675, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.20it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 169 train_batch (Avg. 
Loss 114.441, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.01it/s] test_batch (Avg. Loss 146.404, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.81it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 170 --- EPOCH 171/200 --- train_batch (Avg. Loss 115.719, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 11.06it/s] test_batch (Avg. Loss 142.386, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 19.92it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 171
train_batch (Avg. Loss 114.988, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.84it/s] test_batch (Avg. Loss 139.521, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.27it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 172 train_batch (Avg. Loss 116.437, Accuracy 0.1): 100%|██████████| 60/60 [00:05<00:00, 10.18it/s] test_batch (Avg. Loss 142.710, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.80it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 173 train_batch (Avg. Loss 114.734, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.47it/s] test_batch (Avg. Loss 139.330, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.24it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 174 train_batch (Avg. Loss 114.412, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.02it/s] test_batch (Avg. Loss 138.487, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.66it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 175 train_batch (Avg. Loss 114.875, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.80it/s] test_batch (Avg. Loss 146.521, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.47it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 176 train_batch (Avg. Loss 113.636, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.52it/s] test_batch (Avg. Loss 150.273, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.93it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 177 train_batch (Avg. Loss 113.173, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.92it/s] test_batch (Avg. Loss 144.514, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.41it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 178 train_batch (Avg. Loss 114.116, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.94it/s] test_batch (Avg. Loss 150.953, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.05it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 179 train_batch (Avg. 
Loss 113.395, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.90it/s] test_batch (Avg. Loss 141.208, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.78it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 180 --- EPOCH 181/200 --- train_batch (Avg. Loss 113.704, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.89it/s] test_batch (Avg. Loss 145.290, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.58it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 181
train_batch (Avg. Loss 120.716, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.44it/s] test_batch (Avg. Loss 172.495, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.77it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 182 train_batch (Avg. Loss 118.691, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.27it/s] test_batch (Avg. Loss 142.403, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.70it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 183 train_batch (Avg. Loss 114.099, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.97it/s] test_batch (Avg. Loss 144.786, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.24it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 184 train_batch (Avg. Loss 112.769, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.54it/s] test_batch (Avg. Loss 144.442, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.05it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 185 train_batch (Avg. Loss 111.898, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.85it/s] test_batch (Avg. Loss 142.761, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 18.34it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 186 train_batch (Avg. Loss 113.183, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.66it/s] test_batch (Avg. Loss 144.755, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.49it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 187 train_batch (Avg. Loss 112.361, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.48it/s] test_batch (Avg. Loss 139.458, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.06it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 188 train_batch (Avg. Loss 113.067, Accuracy 0.2): 100%|██████████| 60/60 [00:06<00:00, 9.93it/s] test_batch (Avg. Loss 143.489, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 13.31it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 189 train_batch (Avg. 
Loss 111.502, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.17it/s] test_batch (Avg. Loss 139.886, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.89it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 190 --- EPOCH 191/200 --- train_batch (Avg. Loss 112.113, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.50it/s] test_batch (Avg. Loss 141.475, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.80it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 191
train_batch (Avg. Loss 111.328, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.59it/s] test_batch (Avg. Loss 140.748, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 16.07it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 192 train_batch (Avg. Loss 111.863, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.21it/s] test_batch (Avg. Loss 141.450, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.92it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 193 train_batch (Avg. Loss 111.504, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.55it/s] test_batch (Avg. Loss 140.497, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 17.48it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 194 train_batch (Avg. Loss 110.522, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.85it/s] test_batch (Avg. Loss 140.230, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.86it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 195 train_batch (Avg. Loss 111.849, Accuracy 0.2): 100%|██████████| 60/60 [00:04<00:00, 12.06it/s] test_batch (Avg. Loss 141.514, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 20.58it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 196 train_batch (Avg. Loss 110.656, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.16it/s] test_batch (Avg. Loss 139.679, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 15.91it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 197 train_batch (Avg. Loss 110.681, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.69it/s] test_batch (Avg. Loss 140.087, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 23.71it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 198 train_batch (Avg. Loss 110.851, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 10.44it/s] test_batch (Avg. Loss 145.021, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.48it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 199 --- EPOCH 200/200 --- train_batch (Avg. 
Loss 110.700, Accuracy 0.2): 100%|██████████| 60/60 [00:05<00:00, 11.50it/s] test_batch (Avg. Loss 139.847, Accuracy 0.1): 100%|██████████| 7/7 [00:00<00:00, 14.97it/s] *** Saved checkpoint checkpoints/vae.pt at epoch 200
*** Images Generated from best model:
TODO: Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.
from cs236781.answers import display_answer
import hw3.answers
What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.
display_answer(hw3.answers.part2_q1)
The $\sigma^2$ hyperparameter controls the weight of the data (reconstruction) loss relative to the KL term; it effectively tunes how far the generated samples may stray from the data.
With low values of $\sigma^2$, the reconstruction term dominates and the model generates new data that stays close to the dataset. With high values of $\sigma^2$, the reconstruction term is down-weighted, so the model is more flexible and creative, and its outputs differ more from the dataset than those produced with lower values.
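As a rough sketch (a hypothetical helper, not the actual hw3 implementation), the role of x_sigma2 as a weight on the data term can be illustrated like this:

```python
import torch

def vae_loss_sketch(x, x_rec, mu, logvar, x_sigma2):
    # Reconstruction (data) term, scaled by 1 / x_sigma2: a small x_sigma2
    # makes this term dominate, so samples stay close to the data.
    data_loss = torch.mean((x - x_rec) ** 2) / x_sigma2
    # KL term: pulls the approximate posterior toward the standard normal prior.
    kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1))
    return data_loss + kl_loss
```

For the same inputs, a smaller x_sigma2 yields a larger total loss contribution from the reconstruction term, which is exactly the effect described above.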
display_answer(hw3.answers.part2_q2)
The purpose of the reconstruction term in the loss function is to measure how well the decoder can reconstruct an input image. During training, it drives the model to generate images that resemble the original inputs.
The purpose of the KL divergence term is to measure the difference between the approximate posterior and the prior distribution. Minimizing it pushes the approximate posterior toward the prior, so that sampling from the prior yields good latent-space samples.
The KL loss can thus be seen as a regularization term acting on the encoder's output: it shapes the latent-space distribution to be close to a standard normal distribution.
As with any regularization, the benefit is avoiding overfitting and obtaining better generation results.
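The KL term for a diagonal Gaussian posterior against a standard normal prior has a closed form; a minimal check (using the standard per-sample formula, not the hw3 code itself) shows it vanishes exactly when the posterior equals the prior, which is what makes it act as a regularizer:

```python
import torch

def kl_to_std_normal(mu, logvar):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# Posterior identical to the prior -> zero KL penalty.
zero_kl = kl_to_std_normal(torch.zeros(1, 8), torch.zeros(1, 8))
# Posterior far from the prior -> large penalty.
big_kl = kl_to_std_normal(torch.full((1, 8), 3.0), torch.zeros(1, 8))
```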
In the formulation of the VAE loss, why do we start by maximizing the evidence distribution, $p(\bb{X})$?
display_answer(hw3.answers.part2_q3)
Maximizing the evidence distribution is crucial for making generated samples as close to the data distribution as possible: if $p(\bb{X})$ is high for the training data, then decoded outputs will resemble the dataset with high probability. Since $p(\bb{X})$ is intractable to maximize directly, we instead maximize a lower bound on it (the ELBO), with the expectation that the whole encoding-decoding system improves as the evidence increases.
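The reasoning can be made concrete with the standard decomposition of the log-evidence (written here in the course's notation; this is the textbook identity underlying the ELBO):

```latex
\log p(\bb{X})
 = \underbrace{\mathbb{E}_{q(\bb{Z}|\bb{X})}\!\left[\log \frac{p(\bb{X},\bb{Z})}{q(\bb{Z}|\bb{X})}\right]}_{\text{ELBO}}
 + \underbrace{D_{\mathrm{KL}}\!\left(q(\bb{Z}|\bb{X}) \,\middle\|\, p(\bb{Z}|\bb{X})\right)}_{\ge 0}
```

Because the KL term is non-negative, the ELBO is a lower bound on $\log p(\bb{X})$, so maximizing the ELBO simultaneously pushes up the evidence and pulls the approximate posterior toward the true one.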
In the VAE encoder, why do we model the log of the latent-space variance corresponding to an input, $\bb{\sigma}^2_{\bb{\alpha}}$, instead of directly modelling this variance?
display_answer(hw3.answers.part2_q4)
Modeling the log of the latent-space variance, rather than the variance itself, has the following advantages: the variance must be positive, while the log-variance can be any real number, so the encoder's output head needs no positivity constraint; and working on a log scale is numerically more stable, representing very small and very large variances without underflow or overflow.
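In code, the encoder can output the log-variance directly, and a strictly positive standard deviation is recovered by exponentiation (a sketch of the usual reparameterization step, not the hw3 code itself):

```python
import torch

def reparameterize(mu, logvar):
    # exp(0.5 * logvar) = sigma, guaranteed positive for ANY real-valued
    # network output -- no positivity constraint needed on the encoder head.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

# Even an extreme negative logvar maps to a tiny but valid sigma,
# so z collapses to mu without numerical issues.
z = reparameterize(torch.zeros(2, 4), torch.full((2, 4), -40.0))
```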
In this part we will implement and train a generative adversarial network and apply it to the task of image generation.
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile
import numpy as np
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda
We'll use the same data as in Part 2.
But again, you can use a custom dataset, by editing the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.
import cs236781.plot as plot
import cs236781.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
DATA_URL = CUSTOM_DATA_URL
_, dataset_dir = cs236781.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/tal.r/.pytorch-datasets/lfw-bush.zip exists, skipping download. Extracting /home/tal.r/.pytorch-datasets/lfw-bush.zip... Extracted 531 to /home/tal.r/.pytorch-datasets/lfw/George_W_Bush
Create a Dataset object that will load the extracted images:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
im_size = 64
tf = T.Compose([
# Resize to constant spatial dimensions
T.Resize((im_size, im_size)),
# PIL.Image -> torch.Tensor
T.ToTensor(),
# Dynamic range [0,1] -> [-1, 1]
T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])
ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)
OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(15,10), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)
test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])
GANs, first proposed in a paper by Ian Goodfellow in 2014, are today arguably the most popular type of generative model. GANs are currently producing state-of-the-art results in generative tasks across many different domains.
In a GAN model, two different neural networks compete against each other: A generator and a discriminator.
The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{U} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). Thus a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$ is generated, which we typically would like to be as close as possible to the real evidence distribution, $p(\bb{X})$.
The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

The generator is trained to generate "fake" instances which will maximally fool the discriminator into returning that they're real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:
$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
This means that the generator's loss function trains together with the generator itself in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as the data loss term.
We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.
TODO: Implement the Discriminator class in the hw3/gan.py module.
If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.
import hw3.gan as gan
dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)
d0 = dsc(x0)
print(d0.shape)
test.assertSequenceEqual(d0.shape, (1,1))
Discriminator(
(flatten): Linear(in_features=16384, out_features=1, bias=True)
(encoder): Sequential(
(0): Conv2d(3, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(512, 1024, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
)
)
torch.Size([1, 1])
TODO: Implement the Generator class in the hw3/gan.py module.
If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)
z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)
test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
(decoder): Sequential(
(0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.01)
(3): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): LeakyReLU(negative_slope=0.01)
(6): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): LeakyReLU(negative_slope=0.01)
(9): ConvTranspose2d(128, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(10): Tanh()
)
(linear): Linear(in_features=128, out_features=16384, bias=True)
)
torch.Size([1, 3, 64, 64])
Let's begin with the discriminator's loss function. Based on the above we can flip the sign and say we want to update the Discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$ -\mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, - \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
We're using the Discriminator twice in this expression; once to classify data from the real data distribution and once again to classify generated data. Therefore our loss should be computed based on these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two cross-entropy losses.
GANs are notoriously difficult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization, helping to prevent the discriminator from overfitting.
We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.
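To make this concrete, here is an illustrative sketch of a noisy-label discriminator loss (a design sketch assuming raw scores/logits as discriminator outputs; not necessarily identical to the expected hw3/gan.py implementation):

```python
import torch
import torch.nn.functional as F

def noisy_disc_loss(y_data, y_generated, data_label=1, label_noise=0.3):
    # "Fuzzy" labels: uniform noise in [-label_noise, +label_noise] around
    # the real/fake labels, per the [0 +- eps] / [1 +- eps] idea above
    # (the exact noise range is a design choice).
    noise_real = (torch.rand_like(y_data) * 2 - 1) * label_noise
    noise_fake = (torch.rand_like(y_generated) * 2 - 1) * label_noise
    labels_real = data_label + noise_real
    labels_fake = (1 - data_label) + noise_fake
    # Treat discriminator outputs as raw scores (logits) and use the
    # numerically stable logits version of binary cross-entropy.
    return (F.binary_cross_entropy_with_logits(y_data, labels_real) +
            F.binary_cross_entropy_with_logits(y_generated, labels_fake))

torch.manual_seed(0)
loss = noisy_disc_loss(torch.randn(8), torch.randn(8))
```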
TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)
y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10
loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)
test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)
Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$
which can also be seen as a cross-entropy term. This corresponds to "fooling" the discriminator. Notice that the gradient of the loss w.r.t. $\bb{\gamma}$ using this expression also depends on $\bb{\delta}$.
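As a concrete illustration of this cross-entropy view (a design sketch assuming raw logits as discriminator outputs; not necessarily identical to the expected hw3/gan.py implementation):

```python
import torch
import torch.nn.functional as F

def gen_loss(y_generated, data_label=1):
    # Push the discriminator's scores on generated samples towards the
    # "real" label, which is exactly what fools the discriminator.
    labels = torch.full_like(y_generated, float(data_label))
    return F.binary_cross_entropy_with_logits(y_generated, labels)

torch.manual_seed(0)
loss = gen_loss(torch.randn(8))
```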
TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.
from hw3.gan import generator_loss_fn
torch.manual_seed(42)
y_generated = torch.rand(20) * 10
loss = generator_loss_fn(y_generated, data_label=1)
print(loss)
test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-3)
tensor(0.0223)
Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.
There is an important nuance, however. Sampling is required during the process of training the GAN, since we generate fake images to show the discriminator. As you'll see in the next section, in some cases we'll need our samples to have gradients (i.e., to be part of the Generator's computation graph).
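A minimal sketch of such gradient-optional sampling (with a toy module for illustration; the actual hw3 Generator differs):

```python
import torch

class TinyGenerator(torch.nn.Module):
    def __init__(self, z_dim=8, out_dim=4):
        super().__init__()
        self.net = torch.nn.Linear(z_dim, out_dim)
        self.z_dim = z_dim

    def sample(self, n, with_grad=False):
        # Sample latent vectors on the same device as the model's parameters,
        # and build a computation graph only when gradients are requested.
        device = next(self.parameters()).device
        ctx = torch.enable_grad() if with_grad else torch.no_grad()
        with ctx:
            z = torch.randn(n, self.z_dim, device=device)
            return self.net(z)

g = TinyGenerator()
s_detached = g.sample(5, with_grad=False)  # no graph: grad_fn is None
s_graph = g.sample(5, with_grad=True)      # part of the computation graph
```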
TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())
samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)
Training GANs is a bit different, since we need to train two models simultaneously, each with its own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.
As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)
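To make the order of updates concrete, here is a schematic single-batch sketch with toy modules (illustrative only; the actual train_batch signature and details in hw3/gan.py may differ):

```python
import torch
from torch import nn

class ToyGen(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 3)
    def sample(self, n, with_grad):
        ctx = torch.enable_grad() if with_grad else torch.no_grad()
        with ctx:
            return self.net(torch.randn(n, 2))

dsc = nn.Linear(3, 1)          # toy discriminator (outputs raw logits)
gen = ToyGen()
bce = nn.BCEWithLogitsLoss()
dsc_opt = torch.optim.SGD(dsc.parameters(), lr=0.1)
gen_opt = torch.optim.SGD(gen.net.parameters(), lr=0.1)
x_data = torch.randn(8, 3)     # stands in for a batch of real images

# 1. Discriminator step: fake samples are detached (with_grad=False),
#    so only the discriminator's parameters receive this update.
dsc_opt.zero_grad()
d_loss = (bce(dsc(x_data), torch.ones(8, 1)) +
          bce(dsc(gen.sample(8, with_grad=False)), torch.zeros(8, 1)))
d_loss.backward()
dsc_opt.step()

# 2. Generator step: samples keep gradients (with_grad=True), so the loss
#    backpropagates through the discriminator into the generator. The
#    discriminator also accumulates gradients here, but only gen_opt steps.
gen_opt.zero_grad()
g_loss = bce(dsc(gen.sample(8, with_grad=True)), torch.ones(8, 1))
g_loss.backward()
gen_opt.step()
```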
TODO:
Implement the train_batch function in the hw3/gan.py module.
Set the hyperparameters in the part3_gan_hyperparams() function within the hw3/answers.py module.
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams
torch.manual_seed(42)
# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']
# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape
# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)
# Optimizer
def create_optimizer(model_params, opt_params):
opt_params = opt_params.copy()
optimizer_type = opt_params['type']
opt_params.pop('type')
return optim.__dict__[optimizer_type](model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])
# Loss
def dsc_loss_fn(y_data, y_generated):
return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])
def gen_loss_fn(y_generated):
return gan.generator_loss_fn(y_generated, hp['data_label'])
# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
os.remove(f'{checkpoint_file}.pt')
# Show hypers
print(hp)
{'batch_size': 32, 'z_dim': 4, 'data_label': 1, 'label_noise': 0.2, 'discriminator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.5, 0.99)}, 'generator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.5, 0.99)}}
TODO:
Implement the save_checkpoint function in the hw3.gan module. You can decide on your own criterion regarding whether to save a checkpoint at the end of each epoch.
When you're satisfied with your results, copy your saved checkpoint file to one whose name is suffixed with _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training. Note that your final submission zip will not include the checkpoints/ folder. This is OK.
import IPython.display
import tqdm
from hw3.gan import train_batch, save_checkpoint
num_epochs = 100
if os.path.isfile(f'{checkpoint_file_final}.pt'):
print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
num_epochs = 0
gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device)
checkpoint_file = checkpoint_file_final
try:
dsc_avg_losses, gen_avg_losses = [], []
for epoch_idx in range(num_epochs):
# We'll accumulate batch losses and show an average once per epoch.
dsc_losses, gen_losses = [], []
print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')
with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
for batch_idx, (x_data, _) in enumerate(dl_train):
x_data = x_data.to(device)
dsc_loss, gen_loss = train_batch(
dsc, gen,
dsc_loss_fn, gen_loss_fn,
dsc_optimizer, gen_optimizer,
x_data)
dsc_losses.append(dsc_loss)
gen_losses.append(gen_loss)
pbar.update()
dsc_avg_losses.append(np.mean(dsc_losses))
gen_avg_losses.append(np.mean(gen_losses))
print(f'Discriminator loss: {dsc_avg_losses[-1]}')
print(f'Generator loss: {gen_avg_losses[-1]}')
if save_checkpoint(gen, dsc_avg_losses, gen_avg_losses, checkpoint_file):
print(f'Saved checkpoint.')
samples = gen.sample(5, with_grad=False)
fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
IPython.display.display(fig)
plt.close(fig)
except KeyboardInterrupt as e:
print('\n *** Training interrupted by user')
--- EPOCH 1/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.60it/s] Discriminator loss: 1.3687921157654594 Generator loss: 10.597953936632942
--- EPOCH 2/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.64it/s] Discriminator loss: 0.7453777467941537 Generator loss: 7.795999302583582
--- EPOCH 3/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.72it/s] Discriminator loss: 1.4490364465643377 Generator loss: 5.118890317047343
--- EPOCH 4/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.70it/s] Discriminator loss: 0.6758973791318781 Generator loss: 2.919074731714585 Saved checkpoint.
--- EPOCH 5/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.68it/s] Discriminator loss: 1.0727176192928762 Generator loss: 3.156718506532557 Saved checkpoint.
--- EPOCH 6/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.65it/s] Discriminator loss: 1.0488448125474594 Generator loss: 3.760477030978483
--- EPOCH 7/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.61it/s] Discriminator loss: 0.8943980967297274 Generator loss: 3.372200236600988
--- EPOCH 8/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.70it/s] Discriminator loss: 0.8894501819330103 Generator loss: 3.6717150141211117 Saved checkpoint.
--- EPOCH 9/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.68it/s] Discriminator loss: 1.0471870092784656 Generator loss: 3.6660856639637665
--- EPOCH 10/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.67it/s] Discriminator loss: 0.8394302830976599 Generator loss: 3.0810628358055565 Saved checkpoint.
--- EPOCH 11/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.64it/s] Discriminator loss: 1.1382394713513992 Generator loss: 3.5762756501927093
--- EPOCH 12/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.66it/s] Discriminator loss: 0.8539060161394232 Generator loss: 3.7156750875360824
--- EPOCH 13/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.68it/s] Discriminator loss: 0.7390222163761363 Generator loss: 3.9916387165293976
--- EPOCH 14/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.68it/s] Discriminator loss: 0.715043401455178 Generator loss: 4.170513377470129
--- EPOCH 15/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.63it/s] Discriminator loss: 0.8928559589035371 Generator loss: 4.007720189936021
--- EPOCH 16/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.66it/s] Discriminator loss: 0.79575675638283 Generator loss: 4.382399348651662
--- EPOCH 17/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.6906244106152478 Generator loss: 4.487350646187277
--- EPOCH 18/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.53it/s] Discriminator loss: 0.5203405750148437 Generator loss: 4.786657277275534
--- EPOCH 19/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.8714026861331042 Generator loss: 5.292223215103149
--- EPOCH 20/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.56it/s] Discriminator loss: 0.7658758759498596 Generator loss: 5.009173554532668
--- EPOCH 21/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.5990386219585643 Generator loss: 5.089686940698063 Saved checkpoint.
--- EPOCH 22/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.5415830489467172 Generator loss: 5.013177058276008
--- EPOCH 23/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.7687535877613461 Generator loss: 5.266466491362628
--- EPOCH 24/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.62it/s] Discriminator loss: 0.7200641579487744 Generator loss: 4.307774719069986
--- EPOCH 25/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.5928275642149589 Generator loss: 4.9157918621512025 Saved checkpoint.
--- EPOCH 26/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.61it/s] Discriminator loss: 0.5390444901936194 Generator loss: 5.495094663956586
--- EPOCH 27/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.49it/s] Discriminator loss: 0.5126533109475585 Generator loss: 5.150843536152559
--- EPOCH 28/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.41150490282212987 Generator loss: 5.575077225180233
--- EPOCH 29/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.8448021464488086 Generator loss: 5.2216512595905975
--- EPOCH 30/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.36600004224216237 Generator loss: 5.804453653447768
--- EPOCH 31/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.3695596745785545 Generator loss: 6.618302219054279
--- EPOCH 32/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.7486391492626246 Generator loss: 5.469777527977438
--- EPOCH 33/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.4889012881938149 Generator loss: 5.140636724584243
--- EPOCH 34/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.39112267117289934 Generator loss: 6.5575794191921455
--- EPOCH 35/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.7731939564075541 Generator loss: 5.627653626834645
--- EPOCH 36/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.60it/s] Discriminator loss: 0.5035816691815853 Generator loss: 5.266443434883566
--- EPOCH 37/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.479262352866285 Generator loss: 5.772578491884119
--- EPOCH 38/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.54it/s] Discriminator loss: 0.3737713361487669 Generator loss: 6.3007317991817695
--- EPOCH 39/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.3026555991348098 Generator loss: 6.317666306215174
--- EPOCH 40/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.3310277154121329 Generator loss: 7.044432752272662
--- EPOCH 41/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.5907923539771753 Generator loss: 7.445226557114545
--- EPOCH 42/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.4323918714242823 Generator loss: 5.478950360242059
--- EPOCH 43/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.51it/s] Discriminator loss: 0.5324846497353386 Generator loss: 7.051200838649974 Saved checkpoint.
--- EPOCH 44/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.60it/s] Discriminator loss: 0.4420340411803302 Generator loss: 6.703069238101735
--- EPOCH 45/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.254093616762582 Generator loss: 6.8490418265847595 Saved checkpoint.
--- EPOCH 46/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.40692354782539253 Generator loss: 7.793828795937931
--- EPOCH 47/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.3468475508339265 Generator loss: 6.414452636943144
--- EPOCH 48/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.56it/s] Discriminator loss: 0.2644627805360976 Generator loss: 7.290988894069896 Saved checkpoint.
--- EPOCH 49/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.56it/s] Discriminator loss: 0.2720303583671065 Generator loss: 7.990272353677189
--- EPOCH 50/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.20840508151142037 Generator loss: 8.58718594382791
--- EPOCH 51/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.5230155558708836 Generator loss: 8.028647731332217
--- EPOCH 52/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.54it/s] Discriminator loss: 0.2532957009971142 Generator loss: 6.9867063129649445
--- EPOCH 53/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.53it/s] Discriminator loss: 0.21314705250894322 Generator loss: 6.995286717134364 Saved checkpoint.
--- EPOCH 54/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.5035999759155161 Generator loss: 8.131016226375804
--- EPOCH 55/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.25574299671194134 Generator loss: 7.354196324067957
--- EPOCH 56/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.36546818912029266 Generator loss: 7.19663423650405 Saved checkpoint.
--- EPOCH 57/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.59it/s] Discriminator loss: 0.19092888450797865 Generator loss: 7.579022407531738
--- EPOCH 58/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.56it/s] Discriminator loss: 0.11545712687075138 Generator loss: 6.638127944048713 Saved checkpoint.
--- EPOCH 59/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.53it/s] Discriminator loss: 0.5442311093211174 Generator loss: 7.365904401330387
--- EPOCH 60/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.25056172830655293 Generator loss: 6.986950187122121
--- EPOCH 61/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.27748722363920775 Generator loss: 8.2945055400624
--- EPOCH 62/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.22749030957108035 Generator loss: 9.198413372039795
--- EPOCH 63/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.49it/s] Discriminator loss: 0.27401006166987557 Generator loss: 8.26684960197 Saved checkpoint.
--- EPOCH 64/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.4017600419328493 Generator loss: 8.926264987272376
--- EPOCH 65/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.50it/s] Discriminator loss: 0.26387319954879146 Generator loss: 7.172547228196088 Saved checkpoint.
--- EPOCH 66/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.54it/s] Discriminator loss: 0.7216299312079654 Generator loss: 7.877212314044728
--- EPOCH 67/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.53it/s] Discriminator loss: 0.2120947763323784 Generator loss: 7.111602418562946 Saved checkpoint.
--- EPOCH 68/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.54it/s] Discriminator loss: 0.23267527932629867 Generator loss: 6.917204239789178 Saved checkpoint.
--- EPOCH 69/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.51it/s] Discriminator loss: 0.10407991652541301 Generator loss: 7.186772963579963
--- EPOCH 70/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.43090079593307834 Generator loss: 8.645269113428453
--- EPOCH 71/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.1617605900939773 Generator loss: 7.665558029623592
--- EPOCH 72/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.5727887942510492 Generator loss: 9.02552342414856
--- EPOCH 73/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.2240339153829743 Generator loss: 7.344865518457749
--- EPOCH 74/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.51it/s] Discriminator loss: 0.12431568498997127 Generator loss: 7.83256282525904 Saved checkpoint.
--- EPOCH 75/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.51it/s] Discriminator loss: 0.3352799747577485 Generator loss: 7.68310967613669
--- EPOCH 76/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.16426550881827578 Generator loss: 7.639457422144273
--- EPOCH 77/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.56it/s] Discriminator loss: 0.31469474075471654 Generator loss: 9.073343571494608
--- EPOCH 78/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.53it/s] Discriminator loss: 0.18737244627931537 Generator loss: 7.298575709847843
--- EPOCH 79/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.629780295140603 Generator loss: 8.561202862683464
--- EPOCH 80/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.51it/s] Discriminator loss: 0.22985811969813177 Generator loss: 5.899905120625215
--- EPOCH 81/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.16521840422030756 Generator loss: 7.634655054877786 Saved checkpoint.
--- EPOCH 82/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.49it/s] Discriminator loss: 0.0818397816927994 Generator loss: 7.378693412331974
--- EPOCH 83/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.60it/s] Discriminator loss: 0.22463832938057535 Generator loss: 8.05539108725155
--- EPOCH 84/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.22021520828061245 Generator loss: 10.063086593852324
--- EPOCH 85/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.23119685842710383 Generator loss: 10.432017859290628
--- EPOCH 86/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.39998093576115723 Generator loss: 9.917597995084876
--- EPOCH 87/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.2209381568519508 Generator loss: 7.547923957600313 Saved checkpoint.
--- EPOCH 88/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.58it/s] Discriminator loss: 0.17278862591175473 Generator loss: 7.190260634702795 Saved checkpoint.
--- EPOCH 89/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.0576223400695359 Generator loss: 8.0687414337607
--- EPOCH 90/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.47it/s] Discriminator loss: 0.31069159967934384 Generator loss: 9.309671738568474
--- EPOCH 91/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.25191090168321834 Generator loss: 9.553768354303697
--- EPOCH 92/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.16455630607464733 Generator loss: 8.312699093538171 Saved checkpoint.
--- EPOCH 93/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.29568952497314005 Generator loss: 8.757419263615327
--- EPOCH 94/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.1422532110968057 Generator loss: 7.2009234147913315 Saved checkpoint.
--- EPOCH 95/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.52it/s] Discriminator loss: 0.07062558670911719 Generator loss: 8.180103666642133 Saved checkpoint.
--- EPOCH 96/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.54it/s] Discriminator loss: 0.3427804368822014 Generator loss: 9.811053360209746
--- EPOCH 97/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.17571305198704495 Generator loss: 9.333539345685173
--- EPOCH 98/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.55it/s] Discriminator loss: 0.11203430023263483 Generator loss: 7.845645680147059 Saved checkpoint.
--- EPOCH 99/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.57it/s] Discriminator loss: 0.23948106213527567 Generator loss: 9.12982988357544
--- EPOCH 100/100 --- 100%|██████████| 17/17 [00:03<00:00, 4.54it/s] Discriminator loss: 0.11756093887721791 Generator loss: 8.469492267159854
# Plot images from best or last model
if os.path.isfile(f'{checkpoint_file}.pt'):
gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))
*** Images Generated from best model:
TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.
from cs236781.answers import display_answer
import hw3.answers
Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?
display_answer(hw3.answers.part3_q1)
As we saw, the GAN consists of two parts - a discriminator and a generator. During the training of both, we generate fake data using the generator and feed it into the discriminator along with real data, in order to discriminate between them. The discriminator simply behaves as a classifier, and during its training phase we keep the generator fixed (we sample without maintaining the generator's gradients) in order to "help" the discriminator converge and learn the generator's flaws. During the generator training phase we must of course maintain gradients when sampling, since we need to backpropagate from the very end of the entire model - the discriminator's output - all the way back into the generator's parameters.
When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?
What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?
display_answer(hw3.answers.part3_q2)
Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?
display_answer(hw3.answers.part3_q3)
First, while the VAE and the GAN are similar in their ability to create new generated data, similar to existing real data, they take very different approaches to achieve that goal. The VAE's learning approach is to compress the data correctly in order to be able to reconstruct it later, meaning it focuses directly on the data. The GAN's approach is to train a "cop" so it will be able to distinguish between real and fake data, along with training a "counterfeiter" so it will be able to fool the cop, and vice versa. That is, it creates a competition between the two adversaries, so its focus on the data is indirect in some sense.
We can see that the VAE gave us slightly better results due to the reasons above. Also, because the VAE focuses directly on the pictures, it learned to distinguish meaningful areas such as the face and to ignore the background. That's why the background is blurrier in the VAE while the face is much more precise, whereas in the GAN the different parts of the generated pictures are treated much more uniformly (for better, but mostly for worse).
Another possible reason for the GAN's poorer results (compared to the VAE) is that its training process is hard: the generator and the discriminator constantly try to improve at each other's expense, so we can think of it as shooting at a moving target, rather than the typical, stationary training process of the models we've learned so far, the VAE among them.